From f41aeb9608d40664e80d44c275988dabafff62ff Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Mamy=20Andr=C3=A9-Ratsimbazafy?= Date: Sun, 7 Feb 2021 14:30:32 +0100 Subject: [PATCH 01/22] consistent naming for dbl-width --- benchmarks/bench_fp_double_width.nim | 36 +++++++++---------- .../assembly/limbs_asm_montred_x86.nim | 4 +-- .../limbs_asm_montred_x86_adx_bmi2.nim | 4 +-- constantine/arithmetic/finite_fields.nim | 4 +-- .../arithmetic/finite_fields_double_width.nim | 20 +++++------ constantine/arithmetic/limbs_montgomery.nim | 10 +++--- .../extension_fields.nim | 24 ++++++------- tests/t_finite_fields_double_width.nim | 8 ++--- 8 files changed, 54 insertions(+), 56 deletions(-) diff --git a/benchmarks/bench_fp_double_width.nim b/benchmarks/bench_fp_double_width.nim index 2315de77d..87ff4f37a 100644 --- a/benchmarks/bench_fp_double_width.nim +++ b/benchmarks/bench_fp_double_width.nim @@ -121,12 +121,12 @@ func random_unsafe(rng: var RngState, a: var FpDbl, Base: typedesc) = for i in 0 ..< aHi.mres.limbs.len: a.limbs2x[aLo.mres.limbs.len+i] = aHi.mres.limbs[i] -proc sumNoReduce(T: typedesc, iters: int) = +proc sumUnred(T: typedesc, iters: int) = var r: T let a = rng.random_unsafe(T) let b = rng.random_unsafe(T) - bench("Addition no reduce", $T, iters): - r.sumNoReduce(a, b) + bench("Addition unreduced", $T, iters): + r.sumUnred(a, b) proc sum(T: typedesc, iters: int) = var r: T @@ -135,12 +135,12 @@ proc sum(T: typedesc, iters: int) = bench("Addition", $T, iters): r.sum(a, b) -proc diffNoReduce(T: typedesc, iters: int) = +proc diffUnred(T: typedesc, iters: int) = var r: T let a = rng.random_unsafe(T) let b = rng.random_unsafe(T) - bench("Substraction no reduce", $T, iters): - r.diffNoReduce(a, b) + bench("Substraction unreduced", $T, iters): + r.diffUnred(a, b) proc diff(T: typedesc, iters: int) = var r: T @@ -154,28 +154,28 @@ proc diff2xNoReduce(T: typedesc, iters: int) = rng.random_unsafe(r, T) rng.random_unsafe(a, T) rng.random_unsafe(b, T) - bench("Substraction 2x no reduce", $doubleWidth(T), iters): - r.diffNoReduce(a, b) + bench("Substraction 2x unreduced", $doubleWidth(T), iters): + r.diff2xUnred(a, b) proc diff2x(T: typedesc, iters: int) = var r, a, b: doubleWidth(T) rng.random_unsafe(r, T) rng.random_unsafe(a, T) rng.random_unsafe(b, T) - bench("Substraction 2x", $doubleWidth(T), iters): - r.diff(a, b) + bench("Substraction 2x reduced", $doubleWidth(T), iters): + r.diff2xMod(a, b) -proc mul2xBench*(rLen, aLen, bLen: static int, iters: int) = +proc prod2xBench*(rLen, aLen, bLen: static int, iters: int) = var r: BigInt[rLen] let a = rng.random_unsafe(BigInt[aLen]) let b = rng.random_unsafe(BigInt[bLen]) - bench("Multiplication", $rLen & " <- " & $aLen & " x " & $bLen, iters): + bench("Multiplication 2x", $rLen & " <- " & $aLen & " x " & $bLen, iters): r.prod(a, b) proc square2xBench*(rLen, aLen: static int, iters: int) = var r: BigInt[rLen] let a = rng.random_unsafe(BigInt[aLen]) - bench("Squaring", $rLen & " <- " & $aLen & "²", iters): + bench("Squaring 2x", $rLen & " <- " & $aLen & "²", iters): r.square(a) proc reduce2x*(T: typedesc, iters: int) = @@ -183,18 +183,18 @@ proc reduce2x*(T: typedesc, iters: int) = var t: doubleWidth(T) rng.random_unsafe(t, T) - bench("Reduce 2x-width", $T & " <- " & $doubleWidth(T), iters): - r.reduce(t) + bench("Redc 2x", $T & " <- " & $doubleWidth(T), iters): + r.redc2x(t) proc main() = separator() - sumNoReduce(Fp[BLS12_381], iters = 10_000_000) - diffNoReduce(Fp[BLS12_381], iters = 10_000_000) + sumUnred(Fp[BLS12_381], iters = 10_000_000) + 
diffUnred(Fp[BLS12_381], iters = 10_000_000) sum(Fp[BLS12_381], iters = 10_000_000) diff(Fp[BLS12_381], iters = 10_000_000) diff2x(Fp[BLS12_381], iters = 10_000_000) diff2xNoReduce(Fp[BLS12_381], iters = 10_000_000) - mul2xBench(768, 384, 384, iters = 10_000_000) + prod2xBench(768, 384, 384, iters = 10_000_000) square2xBench(768, 384, iters = 10_000_000) reduce2x(Fp[BLS12_381], iters = 10_000_000) separator() diff --git a/constantine/arithmetic/assembly/limbs_asm_montred_x86.nim b/constantine/arithmetic/assembly/limbs_asm_montred_x86.nim index 42e74e86a..5f119f756 100644 --- a/constantine/arithmetic/assembly/limbs_asm_montred_x86.nim +++ b/constantine/arithmetic/assembly/limbs_asm_montred_x86.nim @@ -83,7 +83,7 @@ proc finalSubCanOverflow*( # Montgomery reduction # ------------------------------------------------------------ -macro montyRed_gen[N: static int]( +macro montyRedc2x_gen[N: static int]( r_MR: var array[N, SecretWord], a_MR: array[N*2, SecretWord], M_MR: array[N, SecretWord], @@ -252,4 +252,4 @@ func montRed_asm*[N: static int]( canUseNoCarryMontyMul: static bool ) = ## Constant-time Montgomery reduction - montyRed_gen(r, a, M, m0ninv, canUseNoCarryMontyMul) + montyRedc2x_gen(r, a, M, m0ninv, canUseNoCarryMontyMul) diff --git a/constantine/arithmetic/assembly/limbs_asm_montred_x86_adx_bmi2.nim b/constantine/arithmetic/assembly/limbs_asm_montred_x86_adx_bmi2.nim index afc26a15c..746cefc05 100644 --- a/constantine/arithmetic/assembly/limbs_asm_montred_x86_adx_bmi2.nim +++ b/constantine/arithmetic/assembly/limbs_asm_montred_x86_adx_bmi2.nim @@ -35,7 +35,7 @@ static: doAssert UseASM_X86_64 # Montgomery reduction # ------------------------------------------------------------ -macro montyRedx_gen[N: static int]( +macro montyRedc2xx_gen[N: static int]( r_MR: var array[N, SecretWord], a_MR: array[N*2, SecretWord], M_MR: array[N, SecretWord], @@ -191,4 +191,4 @@ func montRed_asm_adx_bmi2*[N: static int]( canUseNoCarryMontyMul: static bool ) = ## Constant-time Montgomery reduction - montyRedx_gen(r, a, M, m0ninv, canUseNoCarryMontyMul) + montyRedc2xx_gen(r, a, M, m0ninv, canUseNoCarryMontyMul) diff --git a/constantine/arithmetic/finite_fields.nim b/constantine/arithmetic/finite_fields.nim index 3ed4c1236..044c9205d 100644 --- a/constantine/arithmetic/finite_fields.nim +++ b/constantine/arithmetic/finite_fields.nim @@ -169,7 +169,7 @@ func sum*(r: var FF, a, b: FF) {.meter.} = overflowed = overflowed or not(r.mres < FF.fieldMod()) discard csub(r.mres, FF.fieldMod(), overflowed) -func sumNoReduce*(r: var FF, a, b: FF) {.meter.} = +func sumUnred*(r: var FF, a, b: FF) {.meter.} = ## Sum ``a`` and ``b`` into ``r`` without reduction discard r.mres.sum(a.mres, b.mres) @@ -183,7 +183,7 @@ func diff*(r: var FF, a, b: FF) {.meter.} = var underflowed = r.mres.diff(a.mres, b.mres) discard cadd(r.mres, FF.fieldMod(), underflowed) -func diffNoReduce*(r: var FF, a, b: FF) {.meter.} = +func diffUnred*(r: var FF, a, b: FF) {.meter.} = ## Substract `b` from `a` and store the result into `r` ## without reduction discard r.mres.diff(a.mres, b.mres) diff --git a/constantine/arithmetic/finite_fields_double_width.nim b/constantine/arithmetic/finite_fields_double_width.nim index a9a94dbf5..ce77fa70e 100644 --- a/constantine/arithmetic/finite_fields_double_width.nim +++ b/constantine/arithmetic/finite_fields_double_width.nim @@ -35,18 +35,20 @@ template doubleWidth*(T: typedesc[Fp]): typedesc = func `==`*(a, b: FpDbl): SecretBool = a.limbs2x == b.limbs2x -func mulNoReduce*(r: var FpDbl, a, b: Fp) = +func 
prod2x*(r: var FpDbl, a, b: Fp) = + ## Double-precision multiplication ## Store the product of ``a`` by ``b`` into ``r`` r.limbs2x.prod(a.mres.limbs, b.mres.limbs) -func squareNoReduce*(r: var FpDbl, a: Fp) = +func square2x*(r: var FpDbl, a: Fp) = + ## Double-precision squaring ## Store the square of ``a`` into ``r`` r.limbs2x.square(a.mres.limbs) -func reduce*(r: var Fp, a: FpDbl) = - ## Reduce a double-width field element into r +func redc2x*(r: var Fp, a: FpDbl) = + ## Reduce a double-precision field element into r const N = r.mres.limbs.len - montyRed( + montyRedc2x( r.mres.limbs, a.limbs2x, Fp.C.Mod.limbs, @@ -54,11 +56,11 @@ func reduce*(r: var Fp, a: FpDbl) = Fp.canUseNoCarryMontyMul() ) -func diffNoReduce*(r: var FpDbl, a, b: FpDbl) = +func diff2xUnred*(r: var FpDbl, a, b: FpDbl) = ## Double-width substraction without reduction discard r.limbs2x.diff(a.limbs2x, b.limbs2x) -func diff*(r: var FpDbl, a, b: FpDbl) = +func diff2xMod*(r: var FpDbl, a, b: FpDbl) = ## Double-width modular substraction when UseASM_X86_64: sub2x_asm(r.limbs2x, a.limbs2x, b.limbs2x, FpDbl.C.Mod.limbs) @@ -73,9 +75,5 @@ func diff*(r: var FpDbl, a, b: FpDbl) = addC(carry, sum, r.limbs2x[i+N], M.limbs[i], carry) underflowed.ccopy(r.limbs2x[i+N], sum) -func `-=`*(a: var FpDbl, b: FpDbl) = - ## Double-width modular substraction - a.diff(a, b) - {.pop.} # inline {.pop.} # raises no exceptions diff --git a/constantine/arithmetic/limbs_montgomery.nim b/constantine/arithmetic/limbs_montgomery.nim index e45964c2c..26eff233a 100644 --- a/constantine/arithmetic/limbs_montgomery.nim +++ b/constantine/arithmetic/limbs_montgomery.nim @@ -281,7 +281,7 @@ func montySquare_CIOS(r: var Limbs, a, M: Limbs, m0ninv: BaseType) {.used.}= # Montgomery Reduction # ------------------------------------------------------------ -func montyRed_CIOS[N: static int]( +func montyRedc2x_CIOS[N: static int]( r: var array[N, SecretWord], a: array[N*2, SecretWord], M: array[N, SecretWord], @@ -343,7 +343,7 @@ func montyRed_CIOS[N: static int]( discard res.csub(M, SecretWord(carry).isNonZero() or not(res < M)) r = res -func montyRed_Comba[N: static int]( +func montyRedc2x_Comba[N: static int]( r: var array[N, SecretWord], a: array[N*2, SecretWord], M: array[N, SecretWord], @@ -459,7 +459,7 @@ func montySquare*(r: var Limbs, a, M: Limbs, # montyMul_FIPS(r, a, a, M, m0ninv) # TODO upstream, using Limbs[N] breaks semcheck -func montyRed*[N: static int]( +func montyRedc2x*[N: static int]( r: var array[N, SecretWord], a: array[N*2, SecretWord], M: array[N, SecretWord], @@ -474,8 +474,8 @@ func montyRed*[N: static int]( # TODO: Assembly faster than GCC but slower than Clang montRed_asm(r, a, M, m0ninv, canUseNoCarryMontyMul) else: - montyRed_CIOS(r, a, M, m0ninv) - # montyRed_Comba(r, a, M, m0ninv) + montyRedc2x_CIOS(r, a, M, m0ninv) + # montyRedc2x_Comba(r, a, M, m0ninv) func redc*(r: var Limbs, a, one, M: Limbs, m0ninv: static BaseType, canUseNoCarryMontyMul: static bool) = diff --git a/constantine/tower_field_extensions/extension_fields.nim b/constantine/tower_field_extensions/extension_fields.nim index edd2acd58..bf73397cc 100644 --- a/constantine/tower_field_extensions/extension_fields.nim +++ b/constantine/tower_field_extensions/extension_fields.nim @@ -410,25 +410,25 @@ func prod_complex(r: var QuadraticExt, a, b: QuadraticExt) = var d {.noInit.}: doubleWidth(typeof(r.c0)) const msbSet = r.c0.typeof.canUseNoCarryMontyMul() - a0b0.mulNoReduce(a.c0, b.c0) # 44 cycles - cumul 44 - a1b1.mulNoReduce(a.c1, b.c1) # 44 cycles - cumul 88 + 
a0b0.prod2x(a.c0, b.c0) # 44 cycles - cumul 44 + a1b1.prod2x(a.c1, b.c1) # 44 cycles - cumul 88 when msbSet: r.c0.sum(a.c0, a.c1) r.c1.sum(b.c0, b.c1) else: - r.c0.sumNoReduce(a.c0, a.c1) # 5 cycles - cumul 93 - r.c1.sumNoReduce(b.c0, b.c1) # 5 cycles - cumul 98 + r.c0.sumUnred(a.c0, a.c1) # 5 cycles - cumul 93 + r.c1.sumUnred(b.c0, b.c1) # 5 cycles - cumul 98 # aliasing: a and b unneeded now - d.mulNoReduce(r.c0, r.c1) # 44 cycles - cumul 142 + d.prod2x(r.c0, r.c1) # 44 cycles - cumul 142 when msbSet: - d -= a0b0 - d -= a1b1 + d.diff2xMod(d, a0b0) + d.diff2xMod(d, a1b1) else: - d.diffNoReduce(d, a0b0) # 11 cycles - cumul 153 - d.diffNoReduce(d, a1b1) # 11 cycles - cumul 164 - a0b0.diff(a0b0, a1b1) # 19 cycles - cumul 183 - r.c0.reduce(a0b0) # 50 cycles - cumul 233 - r.c1.reduce(d) # 50 cycles - cumul 288 + d.diff2xUnred(d, a0b0) # 11 cycles - cumul 153 + d.diff2xUnred(d, a1b1) # 11 cycles - cumul 164 + a0b0.diff2xMod(a0b0, a1b1) # 19 cycles - cumul 183 + r.c0.redc2x(a0b0) # 50 cycles - cumul 233 + r.c1.redc2x(d) # 50 cycles - cumul 288 # Single-width [3 Mul, 2 Add, 3 Sub] # 3*88 + 2*14 + 3*14 = 334 theoretical cycles diff --git a/tests/t_finite_fields_double_width.nim b/tests/t_finite_fields_double_width.nim index 0ca027ae3..8435022b5 100644 --- a/tests/t_finite_fields_double_width.nim +++ b/tests/t_finite_fields_double_width.nim @@ -33,8 +33,8 @@ template mulTest(rng_gen: untyped): untyped = var tmpDbl{.noInit.}: FpDbl[C] r_fp.prod(a, b) - tmpDbl.mulNoReduce(a, b) - r_fpDbl.reduce(tmpDbl) + tmpDbl.prod2x(a, b) + r_fpDbl.redc2x(tmpDbl) doAssert bool(r_fp == r_fpDbl) @@ -44,8 +44,8 @@ template sqrTest(rng_gen: untyped): untyped = var mulDbl{.noInit.}, sqrDbl{.noInit.}: FpDbl[C] - mulDbl.mulNoReduce(a, a) - sqrDbl.squareNoReduce(a) + mulDbl.prod2x(a, a) + sqrDbl.square2x(a) doAssert bool(mulDbl == sqrDbl) From 6f56899261967e8baaaec5750bc301f55b845f15 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Mamy=20Andr=C3=A9-Ratsimbazafy?= Date: Sun, 7 Feb 2021 15:42:29 +0100 Subject: [PATCH 02/22] Isolate double-width Fp2 mul --- constantine/tower_field_extensions/README.md | 10 ++ .../extension_fields.nim | 132 ++++++++++++------ 2 files changed, 99 insertions(+), 43 deletions(-) diff --git a/constantine/tower_field_extensions/README.md b/constantine/tower_field_extensions/README.md index baa5463fd..7b5927b84 100644 --- a/constantine/tower_field_extensions/README.md +++ b/constantine/tower_field_extensions/README.md @@ -87,6 +87,16 @@ From Ben Edgington, https://hackmd.io/@benjaminion/bls12-381 Jean-Luc Beuchat and Jorge Enrique González Díaz and Shigeo Mitsunari and Eiji Okamoto and Francisco Rodríguez-Henríquez and Tadanori Teruya, 2010\ https://eprint.iacr.org/2010/354 +- Faster Explicit Formulas for Computing Pairings over Ordinary Curves\ + Diego F. Aranha and Koray Karabina and Patrick Longa and Catherine H. 
Gebotys and Julio López, 2010\ + https://eprint.iacr.org/2010/526.pdf\ + https://www.iacr.org/archive/eurocrypt2011/66320047/66320047.pdf + +- Efficient Implementation of Bilinear Pairings on ARM Processors + Gurleen Grewal, Reza Azarderakhsh, + Patrick Longa, Shi Hu, and David Jao, 2012 + https://eprint.iacr.org/2012/408.pdf + - Choosing and generating parameters for low level pairing implementation on BN curves\ Sylvain Duquesne and Nadia El Mrabet and Safia Haloui and Franck Rondepierre, 2015\ https://eprint.iacr.org/2015/1212 diff --git a/constantine/tower_field_extensions/extension_fields.nim b/constantine/tower_field_extensions/extension_fields.nim index bf73397cc..3cf0d7569 100644 --- a/constantine/tower_field_extensions/extension_fields.nim +++ b/constantine/tower_field_extensions/extension_fields.nim @@ -47,11 +47,11 @@ template c1*(a: ExtensionField): auto = template c2*(a: CubicExt): auto = a.coords[2] -template `c0=`*(a: ExtensionField, v: auto) = +template `c0=`*(a: var ExtensionField, v: auto) = a.coords[0] = v -template `c1=`*(a: ExtensionField, v: auto) = +template `c1=`*(a: var ExtensionField, v: auto) = a.coords[1] = v -template `c2=`*(a: CubicExt, v: auto) = +template `c2=`*(a: var CubicExt, v: auto) = a.coords[2] = v template C*(E: type ExtensionField): Curve = @@ -304,6 +304,87 @@ func prod*(r: var ExtensionField, a: ExtensionField, b: static int) = {.pop.} # inline +# ############################################################ +# # +# Lazy reduced extension fields # +# # +# ############################################################ + +type + QuadraticExt2x[F] = object + ## Quadratic Extension field for lazy reduced fields + coords: array[2, F] + + CubicExt2x[F] = object + ## Cubic Extension field for lazy reduced fields + coords: array[3, F] + + ExtensionField2x[F] = QuadraticExt2x[F] or CubicExt2x[F] + +template doubleWidth(T: typedesc[QuadraticExt]): typedesc = + QuadraticExt2x[doubleWidth(T.F)] + +template doubleWidth(T: typedesc[CubicExt]): typedesc = + CubicExt2x[doubleWidth(T.F)] + +template C(E: type ExtensionField2x): Curve = + E.F.C + +template c0(a: ExtensionField2x): auto = + a.coords[0] +template c1(a: ExtensionField2x): auto = + a.coords[1] +template c2(a: CubicExt2x): auto = + a.coords[2] + +template `c0=`(a: var ExtensionField2x, v: auto) = + a.coords[0] = v +template `c1=`(a: var ExtensionField2x, v: auto) = + a.coords[1] = v +template `c2=`(a: var CubicExt2x, v: auto) = + a.coords[2] = v + +# ############################################################ +# # +# Quadratic extensions - Lazy Reductions # +# # +# ############################################################ + +func prod2x_complex(r: var QuadraticExt2x, a, b: QuadraticExt) = + ## Double-width unreduced multiplication + # r and a or b cannot alias + + mixin fromComplexExtension + static: doAssert a.fromComplexExtension() + + var d {.noInit.}: typeof(r.c0) + var t0 {.noInit.}, t1 {.noInit.}: typeof(a.c0) + const msbSet = a.c0.typeof.canUseNoCarryMontyMul() + + r.c0.prod2x(a.c0, b.c0) # r0 = a0 b0 + d.prod2x(a.c1, b.c1) # d = a1 b1 + when msbSet: + t0.sum(a.c0, a.c1) + t1.sum(b.c0, b.c1) + else: + t0.sumUnred(a.c0, a.c1) + t1.sumUnred(b.c0, b.c1) + r.c1.prod2x(t0, t1) # r1 = (b0 + b1)(a0 + a1) + when msbSet: + r.c1.diff2xMod(r.c1, r.c0) # r1 = (b0 + b1)(a0 + a1) - a0 b0 + r.c1.diff2xMod(r.c1, d) # r1 = (b0 + b1)(a0 + a1) - a0 b0 - a1b1 + else: + r.c1.diff2xUnred(r.c1, r.c0) + r.c1.diff2xUnred(r.c1, d) + r.c0.diff2xMod(r.c0, d) # r0 = a0 b0 - a1 b1 + +# 
############################################################ +# # +# Cubic extensions - Lazy Reductions # +# # +# ############################################################ + + # ############################################################ # # # Quadratic extensions # @@ -386,11 +467,7 @@ func prod_complex(r: var QuadraticExt, a, b: QuadraticExt) = mixin fromComplexExtension static: doAssert r.fromComplexExtension() - # TODO: GCC is adding an unexplainable 30 cycles tax to this function (~10% slow down) - # for seemingly no reason - - when false: # Single-width implementation - BLS12-381 - # Clang 348 cycles on i9-9980XE @3.9 GHz + when false: # Single-width implementation var a0b0 {.noInit.}, a1b1 {.noInit.}: typeof(r.c0) a0b0.prod(a.c0, b.c0) # [1 Mul] a1b1.prod(a.c1, b.c1) # [2 Mul] @@ -403,43 +480,12 @@ func prod_complex(r: var QuadraticExt, a, b: QuadraticExt) = r.c0.diff(a0b0, a1b1) # r0 = a0 b0 - a1 b1 # [3 Mul, 2 Add, 1 Sub] r.c1 -= a0b0 # r1 = (b0 + b1)(a0 + a1) - a0b0 # [3 Mul, 2 Add, 2 Sub] r.c1 -= a1b1 # r1 = (b0 + b1)(a0 + a1) - a0b0 - a1b1 # [3 Mul, 2 Add, 3 Sub] - else: # Double-width implementation with lazy reduction # Clang 341 cycles on i9-9980XE @3.9 GHz - var a0b0 {.noInit.}, a1b1 {.noInit.}: doubleWidth(typeof(r.c0)) - var d {.noInit.}: doubleWidth(typeof(r.c0)) - const msbSet = r.c0.typeof.canUseNoCarryMontyMul() - - a0b0.prod2x(a.c0, b.c0) # 44 cycles - cumul 44 - a1b1.prod2x(a.c1, b.c1) # 44 cycles - cumul 88 - when msbSet: - r.c0.sum(a.c0, a.c1) - r.c1.sum(b.c0, b.c1) - else: - r.c0.sumUnred(a.c0, a.c1) # 5 cycles - cumul 93 - r.c1.sumUnred(b.c0, b.c1) # 5 cycles - cumul 98 - # aliasing: a and b unneeded now - d.prod2x(r.c0, r.c1) # 44 cycles - cumul 142 - when msbSet: - d.diff2xMod(d, a0b0) - d.diff2xMod(d, a1b1) - else: - d.diff2xUnred(d, a0b0) # 11 cycles - cumul 153 - d.diff2xUnred(d, a1b1) # 11 cycles - cumul 164 - a0b0.diff2xMod(a0b0, a1b1) # 19 cycles - cumul 183 - r.c0.redc2x(a0b0) # 50 cycles - cumul 233 - r.c1.redc2x(d) # 50 cycles - cumul 288 - - # Single-width [3 Mul, 2 Add, 3 Sub] - # 3*88 + 2*14 + 3*14 = 334 theoretical cycles - # 348 measured - # Double-Width - # 288 theoretical cycles - # 329 measured - # Unexplained 40 cycles diff between theo and measured - # and unexplained 30 cycles between Clang and GCC - # - Function calls? - # - push/pop stack? 
+ var d {.noInit.}: doubleWidth(typeof(r)) + d.prod2x_complex(a, b) + r.c0.redc2x(d.c0) + r.c1.redc2x(d.c1) func mul_sparse_complex_by_0y( r: var QuadraticExt, a: QuadraticExt, From 180c8396a7da4d0195f7668e6ebe5c2e08ca4f2e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Mamy=20Andr=C3=A9-Ratsimbazafy?= Date: Sun, 7 Feb 2021 16:09:16 +0100 Subject: [PATCH 03/22] Implement double-width complex multiplication --- .../extension_fields.nim | 76 +++++++++++++------ 1 file changed, 52 insertions(+), 24 deletions(-) diff --git a/constantine/tower_field_extensions/extension_fields.nim b/constantine/tower_field_extensions/extension_fields.nim index 3cf0d7569..7de51c997 100644 --- a/constantine/tower_field_extensions/extension_fields.nim +++ b/constantine/tower_field_extensions/extension_fields.nim @@ -350,8 +350,11 @@ template `c2=`(a: var CubicExt2x, v: auto) = # # # ############################################################ +# Commutative ring implementation for complex quadratic extension fields +# ---------------------------------------------------------------------- + func prod2x_complex(r: var QuadraticExt2x, a, b: QuadraticExt) = - ## Double-width unreduced multiplication + ## Double-width unreduced complex multiplication # r and a or b cannot alias mixin fromComplexExtension @@ -378,6 +381,26 @@ func prod2x_complex(r: var QuadraticExt2x, a, b: QuadraticExt) = r.c1.diff2xUnred(r.c1, d) r.c0.diff2xMod(r.c0, d) # r0 = a0 b0 - a1 b1 +func square2x_complex(r: var QuadraticExt2x, a: QuadraticExt) = + ## Double-width unreduced complex squaring + + mixin fromComplexExtension + static: doAssert a.fromComplexExtension() + + var t0 {.noInit.}, t1 {.noInit.}: typeof(a.c0) + const msbSet = a.c0.typeof.canUseNoCarryMontyMul() + + when msbSet: + t0.double(a.c1) + t1.sum(a.c0, a.c1) + else: + t0.sumUnred(a.c1, a.c1) + t1.sum(a.c0, a.c1) + + r.c1.prod2x(t0, a.c0) # r1 = 2a0a1 + t0.diff(a.c0, a.c1) + r.c0.prod2x(t0, t1) # r0 = (a0 + a1)(a0 - a1) + # ############################################################ # # # Cubic extensions - Lazy Reductions # @@ -467,25 +490,18 @@ func prod_complex(r: var QuadraticExt, a, b: QuadraticExt) = mixin fromComplexExtension static: doAssert r.fromComplexExtension() - when false: # Single-width implementation - var a0b0 {.noInit.}, a1b1 {.noInit.}: typeof(r.c0) - a0b0.prod(a.c0, b.c0) # [1 Mul] - a1b1.prod(a.c1, b.c1) # [2 Mul] - - r.c0.sum(a.c0, a.c1) # r0 = (a0 + a1) # [2 Mul, 1 Add] - r.c1.sum(b.c0, b.c1) # r1 = (b0 + b1) # [2 Mul, 2 Add] - # aliasing: a and b unneeded now - r.c1 *= r.c0 # r1 = (b0 + b1)(a0 + a1) # [3 Mul, 2 Add] - 𝔽p temporary - - r.c0.diff(a0b0, a1b1) # r0 = a0 b0 - a1 b1 # [3 Mul, 2 Add, 1 Sub] - r.c1 -= a0b0 # r1 = (b0 + b1)(a0 + a1) - a0b0 # [3 Mul, 2 Add, 2 Sub] - r.c1 -= a1b1 # r1 = (b0 + b1)(a0 + a1) - a0b0 - a1b1 # [3 Mul, 2 Add, 3 Sub] - else: # Double-width implementation with lazy reduction - # Clang 341 cycles on i9-9980XE @3.9 GHz - var d {.noInit.}: doubleWidth(typeof(r)) - d.prod2x_complex(a, b) - r.c0.redc2x(d.c0) - r.c1.redc2x(d.c1) + var a0b0 {.noInit.}, a1b1 {.noInit.}: typeof(r.c0) + a0b0.prod(a.c0, b.c0) # [1 Mul] + a1b1.prod(a.c1, b.c1) # [2 Mul] + + r.c0.sum(a.c0, a.c1) # r0 = (a0 + a1) # [2 Mul, 1 Add] + r.c1.sum(b.c0, b.c1) # r1 = (b0 + b1) # [2 Mul, 2 Add] + # aliasing: a and b unneeded now + r.c1 *= r.c0 # r1 = (b0 + b1)(a0 + a1) # [3 Mul, 2 Add] - 𝔽p temporary + + r.c0.diff(a0b0, a1b1) # r0 = a0 b0 - a1 b1 # [3 Mul, 2 Add, 1 Sub] + r.c1 -= a0b0 # r1 = (b0 + b1)(a0 + a1) - a0b0 # [3 Mul, 2 Add, 2 Sub] + r.c1 -= a1b1 # r1 = (b0 + 
b1)(a0 + a1) - a0b0 - a1b1 # [3 Mul, 2 Add, 3 Sub] func mul_sparse_complex_by_0y( r: var QuadraticExt, a: QuadraticExt, @@ -704,22 +720,34 @@ func invImpl(r: var QuadraticExt, a: QuadraticExt) = # Exported quadratic symbols # ------------------------------------------------------------------- -{.push inline.} - func square*(r: var QuadraticExt, a: QuadraticExt) = mixin fromComplexExtension when r.fromComplexExtension(): - r.square_complex(a) + when true: + r.square_complex(a) + else: # slower + var d {.noInit.}: doubleWidth(typeof(r)) + d.square2x_complex(a) + r.c0.redc2x(d.c0) + r.c1.redc2x(d.c1) else: r.square_generic(a) func prod*(r: var QuadraticExt, a, b: QuadraticExt) = mixin fromComplexExtension when r.fromComplexExtension(): - r.prod_complex(a, b) + when false: + r.prod_complex(a, b) + else: # faster + var d {.noInit.}: doubleWidth(typeof(r)) + d.prod2x_complex(a, b) + r.c0.redc2x(d.c0) + r.c1.redc2x(d.c1) else: r.prod_generic(a, b) +{.push inline.} + func inv*(r: var QuadraticExt, a: QuadraticExt) = ## Compute the multiplicative inverse of ``a`` ## From acc4afe7a09fe6f6f9b724504afabfd0fb8c36bc Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Mamy=20Andr=C3=A9-Ratsimbazafy?= Date: Sun, 7 Feb 2021 20:19:43 +0100 Subject: [PATCH 04/22] Lay out Fp4 double-width mul --- .../arithmetic/finite_fields_double_width.nim | 21 +- .../extension_fields.nim | 212 +++++++++++++++++- 2 files changed, 224 insertions(+), 9 deletions(-) diff --git a/constantine/arithmetic/finite_fields_double_width.nim b/constantine/arithmetic/finite_fields_double_width.nim index ce77fa70e..7253bde04 100644 --- a/constantine/arithmetic/finite_fields_double_width.nim +++ b/constantine/arithmetic/finite_fields_double_width.nim @@ -24,7 +24,7 @@ type FpDbl*[C: static Curve] = object # We directly work with double the number of limbs limbs2x*: matchingLimbs2x(C) -template doubleWidth*(T: typedesc[Fp]): typedesc = +template doubleWidth*(T: type Fp): type = ## Return the double-width type matching with Fp FpDbl[T.C] @@ -75,5 +75,24 @@ func diff2xMod*(r: var FpDbl, a, b: FpDbl) = addC(carry, sum, r.limbs2x[i+N], M.limbs[i], carry) underflowed.ccopy(r.limbs2x[i+N], sum) +func sum2xUnred*(r: var FpDbl, a, b: FpDbl) = + ## Double-width addition without reduction + discard r.limbs2x.sum(a.limbs2x, b.limbs2x) + +func sum2xMod*(r: var FpDbl, a, b: FpDbl) = + ## Double-width modular substraction + when false: # TODO: UseASM_X86_64: + sum2x_asm(r.limbs2x, a.limbs2x, b.limbs2x, FpDbl.C.Mod.limbs) + else: + var overflowed = SecretBool r.limbs2x.sum(a.limbs2x, b.limbs2x) + + const N = r.limbs2x.len div 2 + const M = FpDbl.C.Mod + var borrow = Borrow(0) + var diff: SecretWord + for i in 0 ..< N: + subB(borrow, diff, r.limbs2x[i+N], M.limbs[i], borrow) + overflowed.ccopy(r.limbs2x[i+N], diff) + {.pop.} # inline {.pop.} # raises no exceptions diff --git a/constantine/tower_field_extensions/extension_fields.nim b/constantine/tower_field_extensions/extension_fields.nim index 7de51c997..d40654ff7 100644 --- a/constantine/tower_field_extensions/extension_fields.nim +++ b/constantine/tower_field_extensions/extension_fields.nim @@ -321,11 +321,21 @@ type ExtensionField2x[F] = QuadraticExt2x[F] or CubicExt2x[F] -template doubleWidth(T: typedesc[QuadraticExt]): typedesc = - QuadraticExt2x[doubleWidth(T.F)] - -template doubleWidth(T: typedesc[CubicExt]): typedesc = - CubicExt2x[doubleWidth(T.F)] +template doubleWidth(T: type ExtensionField): type = + # For now naive unrolling, recursive template don't match + # and I don't want to deal with types in 
macros + when T is QuadraticExt: + when T.F is QuadraticExt: # Fp4Dbl + QuadraticExt2x[QuadraticExt2x[doubleWidth(T.F.F)]] + elif T.F is Fp: # Fp2Dbl + QuadraticExt2x[doubleWidth(T.F)] + elif T is CubicExt: + when T.F is QuadraticExt: # Fp6Dbl + CubicExt2x[QuadraticExt2x[doubleWidth(T.F.F)]] + +func has2extraBits(E: type ExtensionField): bool = + ## We construct extensions only on Fp (and not Fr) + canUseNoCarryMontySquare(Fp[E.C]) template C(E: type ExtensionField2x): Curve = E.F.C @@ -344,6 +354,144 @@ template `c1=`(a: var ExtensionField2x, v: auto) = template `c2=`(a: var CubicExt2x, v: auto) = a.coords[2] = v +# Abelian group +# ------------------------------------------------------------------- + +func sumUnred(r: var ExtensionField, a, b: ExtensionField) = + ## Sum ``a`` and ``b`` into ``r`` + staticFor i, 0, a.coords.len: + r.coords[i].sumUnred(a.coords[i], b.coords[i]) + +func diff2xUnred(r: var ExtensionField2x, a, b: ExtensionField2x) = + ## Double-width substraction without reduction + staticFor i, 0, a.coords.len: + r.coords[i].diff2xUnred(a.coords[i], b.coords[i]) + +func diff2xMod(r: var ExtensionField2x, a, b: ExtensionField2x) = + ## Double-width modular substraction + staticFor i, 0, a.coords.len: + r.coords[i].diff2xMod(a.coords[i], b.coords[i]) + +func sum2xUnred(r: var ExtensionField2x, a, b: ExtensionField2x) = + ## Double-width addition without reduction + staticFor i, 0, a.coords.len: + r.coords[i].sum2xUnred(a.coords[i], b.coords[i]) + +func sum2xMod(r: var ExtensionField2x, a, b: ExtensionField2x) = + ## Double-width modular addition + staticFor i, 0, a.coords.len: + r.coords[i].sum2xMod(a.coords[i], b.coords[i]) + +# Reductions +# ------------------------------------------------------------------- + +func redc2x(r: var ExtensionField, a: ExtensionField2x) = + ## Reduction + staticFor i, 0, a.coords.len: + r.coords[i].redc2x(a.coords[i]) + +# Multiplication by a small integer known at compile-time +# ------------------------------------------------------------------- + +func prod( + r {.noAlias.}: var ExtensionField2x, + a {.noAlias.}: ExtensionField2x, b: static int) = + ## Multiplication by a small integer known at compile-time + ## Requires no aliasing + const negate = b < 0 + const b = if negate: -b + else: b + when negate: + r.diff2xMod(typeof(a)(), a) + when b == 0: + r.setZero() + elif b == 1: + return + elif b == 2: + r.sum2xMod(a, a) + elif b == 3: + r.sum2xMod(a, a) + r.sum2xMod(a, r) + elif b == 4: + r.sum2xMod(a, a) + r.sum2xMod(r, r) + elif b == 5: + r.sum2xMod(a, a) + r.sum2xMod(r, r) + r.sum2xMod(r, a) + elif b == 6: + r.sum2xMod(a, a) + let t2 = r + r.sum2xMod(r, r) # 4 + r.sum2xMod(t, t2) + elif b == 7: + r.sum2xMod(a, a) + r.sum2xMod(r, r) # 4 + r.sum2xMod(r, r) + r.diff2xMod(r, a) + elif b == 8: + r.sum2xMod(a, a) + r.sum2xMod(r, r) + r.sum2xMod(r, r) + elif b == 9: + r.sum2xMod(a, a) + r.sum2xMod(r, r) + r.sum2xMod(r, r) # 8 + r.sum2xMod(r, a) + elif b == 10: + a.sum2xMod(a, a) + r.sum2xMod(r, r) + r.sum2xMod(r, a) # 5 + r.sum2xMod(r, r) + elif b == 11: + a.sum2xMod(a, a) + r.sum2xMod(r, r) + r.sum2xMod(r, a) # 5 + r.sum2xMod(r, r) + r.sum2xMod(r, a) + elif b == 12: + a.sum2xMod(a, a) + r.sum2xMod(r, r) # 4 + let t4 = a + r.sum2xMod(r, r) # 8 + r.sum2xMod(r, t4) + else: + {.error: "Multiplication by this small int not implemented".} + +# NonResidue +# ---------------------------------------------------------------------- + +func prod2x[C: static Curve]( + r {.noalias.}: var QuadraticExt2x[FpDbl[C]], + a {.noalias.}: QuadraticExt2x[FpDbl[C]], + 
_: type NonResidue) {.inline.} = + ## Multiplication by non-residue + ## ! no aliasing! + const complex = C.getNonResidueFp() == -1 + const U = C.getNonResidueFp2()[0] + const V = C.getNonResidueFp2()[1] + const Beta {.used.} = C.getNonResidueFp() + + when complex and U == 1 and V == 1: + r.c0.diff2xMod(a.c0, a.c1) + r.c1.sum2xMod(a.c0, a.c1) + else: + # ξ = u + v x + # and x² = β + # + # (c0 + c1 x) (u + v x) => u c0 + (u c0 + u c1)x + v c1 x² + # => u c0 + β v c1 + (v c0 + u c1) x + var t {.noInit.}: FpDbl[C] + + t.prod(a.c1, V) + r.c0.prod(t, Beta) + t.prod(a.c0, U) + r.c0.sum2xMod(t) # r0 = u c0 + β v c1 + + r.c1.prod(a.c0, V) + t.prod(a.c1, U) + r.c1.sum2xMod(t) # r1 = v c0 + u c1 + # ############################################################ # # # Quadratic extensions - Lazy Reductions # @@ -388,9 +536,9 @@ func square2x_complex(r: var QuadraticExt2x, a: QuadraticExt) = static: doAssert a.fromComplexExtension() var t0 {.noInit.}, t1 {.noInit.}: typeof(a.c0) - const msbSet = a.c0.typeof.canUseNoCarryMontyMul() - when msbSet: + # Require 2 extra bits + when QuadraticExt.has2extraBits(): t0.double(a.c1) t1.sum(a.c0, a.c1) else: @@ -401,6 +549,48 @@ func square2x_complex(r: var QuadraticExt2x, a: QuadraticExt) = t0.diff(a.c0, a.c1) r.c0.prod2x(t0, t1) # r0 = (a0 + a1)(a0 - a1) +func prod2x(r: var QuadraticExt2x, a, b: QuadraticExt) = + mixin fromComplexExtension + when a.fromComplexExtension(): + r.prod2x_complex(a, b) + else: + {.error: "Not Implementated".} + +# Commutative ring implementation for generic quadratic extension fields +# ---------------------------------------------------------------------- +# +# Some sparse functions, reconstruct a Fp4 from disjoint pieces +# to limit copies, we provide versions with disjoint elements + +func prod2x_disjoint[Fdbl, F]( + r: var QuadraticExt2x[FDbl], + a: QuadraticExt[F], + b0, b1: F + ) = + ## Return a * (b0, b1) in r + static: doAssert Fdbl is doubleWidth(F) + + var v0 {.noInit.}, v1 {.noInit.}: typeof(r.c0) # Double-width + var t0 {.noInit.}, t1 {.noInit.}: typeof(a.c0) # Single-width + + # Require 2 extra bits + v0.prod2x(a.c0, b0) + v1.prod2x(a.c1, b1) + when F.has2extraBits(): + t0.sumUnred(b0, b1) + t1.sumUnred(a.c0, a.c1) + else: + t0.sum(b0, b1) + t1.sum(a.c0, a.c1) + + r.c1.prod2x(t0, t1) # r1 = (a0 + a1)(b0 + b1) , and at most 2 extra bits + r.c1.diff2xMod(r.c1, v0) # r1 = (a0 + a1)(b0 + b1) - a0b0 , and at most 1 extra bit + r.c1.diff2xMod(r.c1, v1) # r1 = (a0 + a1)(b0 + b1) - a0b0 - a1b1, and 0 extra bit + + # TODO: This is correct modulo p, but we have some extra bits still here. 
+ r.c0.prod2x(v1, NonResidue) # r0 = β a1 b1 + r.c0.sum2xUnred(r.c0, v0) # r0 = a0 b0 + β a1 b1 + # ############################################################ # # # Cubic extensions - Lazy Reductions # @@ -744,7 +934,13 @@ func prod*(r: var QuadraticExt, a, b: QuadraticExt) = r.c0.redc2x(d.c0) r.c1.redc2x(d.c1) else: - r.prod_generic(a, b) + when true: # typeof(r.c0) is Fp: + r.prod_generic(a, b) + else: # Deactivated r.c0 is correct modulo p but >= p + var d {.noInit.}: doubleWidth(typeof(r)) + d.prod2x_disjoint(a, b.c0, b.c1) + r.c0.redc2x(d.c0) + r.c1.redc2x(d.c1) {.push inline.} From e07a94715fab6ee77b9e7bc392cf8c635ea2550c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Mamy=20Andr=C3=A9-Ratsimbazafy?= Date: Sun, 7 Feb 2021 23:53:46 +0100 Subject: [PATCH 05/22] Off by p in square Fp4 as well :/ --- constantine.nimble | 1 + .../arithmetic/finite_fields_double_width.nim | 68 +++++++++ .../extension_fields.nim | 144 +++++++++++++----- tests/t_fp4.nim | 129 ++++++++++++++++ tests/t_fp_tower_template.nim | 7 +- 5 files changed, 311 insertions(+), 38 deletions(-) create mode 100644 tests/t_fp4.nim diff --git a/constantine.nimble b/constantine.nimble index bc42e1434..bb3d0bb90 100644 --- a/constantine.nimble +++ b/constantine.nimble @@ -50,6 +50,7 @@ const testDesc: seq[tuple[path: string, useGMP: bool]] = @[ # ---------------------------------------------------------- ("tests/t_fp2.nim", false), ("tests/t_fp2_sqrt.nim", false), + ("tests/t_fp4.nim", false), ("tests/t_fp6_bn254_snarks.nim", false), ("tests/t_fp6_bls12_377.nim", false), ("tests/t_fp6_bls12_381.nim", false), diff --git a/constantine/arithmetic/finite_fields_double_width.nim b/constantine/arithmetic/finite_fields_double_width.nim index 7253bde04..a2d59de6a 100644 --- a/constantine/arithmetic/finite_fields_double_width.nim +++ b/constantine/arithmetic/finite_fields_double_width.nim @@ -35,6 +35,9 @@ template doubleWidth*(T: type Fp): type = func `==`*(a, b: FpDbl): SecretBool = a.limbs2x == b.limbs2x +func setZero*(a: var FpDbl) = + a.limbs2x.setZero() + func prod2x*(r: var FpDbl, a, b: Fp) = ## Double-precision multiplication ## Store the product of ``a`` by ``b`` into ``r`` @@ -94,5 +97,70 @@ func sum2xMod*(r: var FpDbl, a, b: FpDbl) = subB(borrow, diff, r.limbs2x[i+N], M.limbs[i], borrow) overflowed.ccopy(r.limbs2x[i+N], diff) +func prod2x*( + r {.noAlias.}: var FpDbl, + a {.noAlias.}: FpDbl, b: static int) = + ## Multiplication by a small integer known at compile-time + ## Requires no aliasing + const negate = b < 0 + const b = if negate: -b + else: b + when negate: + r.diff2xMod(typeof(a)(), a) + when b == 0: + r.setZero() + elif b == 1: + r = a + elif b == 2: + r.sum2xMod(a, a) + elif b == 3: + r.sum2xMod(a, a) + r.sum2xMod(a, r) + elif b == 4: + r.sum2xMod(a, a) + r.sum2xMod(r, r) + elif b == 5: + r.sum2xMod(a, a) + r.sum2xMod(r, r) + r.sum2xMod(r, a) + elif b == 6: + r.sum2xMod(a, a) + let t2 = r + r.sum2xMod(r, r) # 4 + r.sum2xMod(t, t2) + elif b == 7: + r.sum2xMod(a, a) + r.sum2xMod(r, r) # 4 + r.sum2xMod(r, r) + r.diff2xMod(r, a) + elif b == 8: + r.sum2xMod(a, a) + r.sum2xMod(r, r) + r.sum2xMod(r, r) + elif b == 9: + r.sum2xMod(a, a) + r.sum2xMod(r, r) + r.sum2xMod(r, r) # 8 + r.sum2xMod(r, a) + elif b == 10: + a.sum2xMod(a, a) + r.sum2xMod(r, r) + r.sum2xMod(r, a) # 5 + r.sum2xMod(r, r) + elif b == 11: + a.sum2xMod(a, a) + r.sum2xMod(r, r) + r.sum2xMod(r, a) # 5 + r.sum2xMod(r, r) + r.sum2xMod(r, a) + elif b == 12: + a.sum2xMod(a, a) + r.sum2xMod(r, r) # 4 + let t4 = a + r.sum2xMod(r, r) # 8 + r.sum2xMod(r, t4) + else: + 
{.error: "Multiplication by this small int not implemented".} + {.pop.} # inline {.pop.} # raises no exceptions diff --git a/constantine/tower_field_extensions/extension_fields.nim b/constantine/tower_field_extensions/extension_fields.nim index d40654ff7..a35229e4e 100644 --- a/constantine/tower_field_extensions/extension_fields.nim +++ b/constantine/tower_field_extensions/extension_fields.nim @@ -333,6 +333,10 @@ template doubleWidth(T: type ExtensionField): type = when T.F is QuadraticExt: # Fp6Dbl CubicExt2x[QuadraticExt2x[doubleWidth(T.F.F)]] +func has1extraBit(E: type ExtensionField): bool = + ## We construct extensions only on Fp (and not Fr) + canUseNoCarryMontyMul(Fp[E.C]) + func has2extraBits(E: type ExtensionField): bool = ## We construct extensions only on Fp (and not Fr) canUseNoCarryMontySquare(Fp[E.C]) @@ -354,6 +358,14 @@ template `c1=`(a: var ExtensionField2x, v: auto) = template `c2=`(a: var CubicExt2x, v: auto) = a.coords[2] = v +# Initialization +# ------------------------------------------------------------------- + +func setZero*(a: var ExtensionField2x) = + ## Set ``a`` to 0 in the extension field + staticFor i, 0, a.coords.len: + a.coords[i].setZero() + # Abelian group # ------------------------------------------------------------------- @@ -401,12 +413,13 @@ func prod( const negate = b < 0 const b = if negate: -b else: b + when negate: r.diff2xMod(typeof(a)(), a) when b == 0: r.setZero() elif b == 1: - return + r = a elif b == 2: r.sum2xMod(a, a) elif b == 3: @@ -439,18 +452,18 @@ func prod( r.sum2xMod(r, r) # 8 r.sum2xMod(r, a) elif b == 10: - a.sum2xMod(a, a) + r.sum2xMod(a, a) r.sum2xMod(r, r) r.sum2xMod(r, a) # 5 r.sum2xMod(r, r) elif b == 11: - a.sum2xMod(a, a) + r.sum2xMod(a, a) r.sum2xMod(r, r) r.sum2xMod(r, a) # 5 r.sum2xMod(r, r) r.sum2xMod(r, a) elif b == 12: - a.sum2xMod(a, a) + r.sum2xMod(a, a) r.sum2xMod(r, r) # 4 let t4 = a r.sum2xMod(r, r) # 8 @@ -461,6 +474,12 @@ func prod( # NonResidue # ---------------------------------------------------------------------- +func prod2x(r: var FpDbl, a: FpDbl, _: type NonResidue){.inline.} = + ## Multiply an element of 𝔽p by the quadratic non-residue + ## chosen to construct 𝔽p2 + static: doAssert FpDbl.C.getNonResidueFp() != -1, "𝔽p2 should be specialized for complex extension" + r.prod2x(a, FpDbl.C.getNonResidueFp()) + func prod2x[C: static Curve]( r {.noalias.}: var QuadraticExt2x[FpDbl[C]], a {.noalias.}: QuadraticExt2x[FpDbl[C]], @@ -483,14 +502,15 @@ func prod2x[C: static Curve]( # => u c0 + β v c1 + (v c0 + u c1) x var t {.noInit.}: FpDbl[C] - t.prod(a.c1, V) - r.c0.prod(t, Beta) - t.prod(a.c0, U) - r.c0.sum2xMod(t) # r0 = u c0 + β v c1 + r.c0.prod2x(a.c0, U) + when V == 1 and Beta == -1: # Case BN254_Snarks + r.c0.diff2xMod(r.c0, a.c1) # r0 = u c0 + β v c1 + else: + {.error: "Unimplemented".} - r.c1.prod(a.c0, V) - t.prod(a.c1, U) - r.c1.sum2xMod(t) # r1 = v c0 + u c1 + r.c1.prod2x(a.c0, V) + t.prod2x(a.c1, U) + r.c1.sum2xMod(r.c1, t) # r1 = v c0 + u c1 # ############################################################ # # @@ -508,26 +528,25 @@ func prod2x_complex(r: var QuadraticExt2x, a, b: QuadraticExt) = mixin fromComplexExtension static: doAssert a.fromComplexExtension() - var d {.noInit.}: typeof(r.c0) + var D {.noInit.}: typeof(r.c0) var t0 {.noInit.}, t1 {.noInit.}: typeof(a.c0) - const msbSet = a.c0.typeof.canUseNoCarryMontyMul() r.c0.prod2x(a.c0, b.c0) # r0 = a0 b0 - d.prod2x(a.c1, b.c1) # d = a1 b1 - when msbSet: - t0.sum(a.c0, a.c1) - t1.sum(b.c0, b.c1) - else: + D.prod2x(a.c1, b.c1) # d = a1 b1 + 
when QuadraticExt.has1extraBit(): t0.sumUnred(a.c0, a.c1) t1.sumUnred(b.c0, b.c1) + else: + t0.sum(a.c0, a.c1) + t1.sum(b.c0, b.c1) r.c1.prod2x(t0, t1) # r1 = (b0 + b1)(a0 + a1) - when msbSet: - r.c1.diff2xMod(r.c1, r.c0) # r1 = (b0 + b1)(a0 + a1) - a0 b0 - r.c1.diff2xMod(r.c1, d) # r1 = (b0 + b1)(a0 + a1) - a0 b0 - a1b1 + when QuadraticExt.has1extraBit(): + r.c1.diff2xUnred(r.c1, r.c0) # r1 = (b0 + b1)(a0 + a1) - a0 b0 + r.c1.diff2xUnred(r.c1, D) # r1 = (b0 + b1)(a0 + a1) - a0 b0 - a1b1 else: - r.c1.diff2xUnred(r.c1, r.c0) - r.c1.diff2xUnred(r.c1, d) - r.c0.diff2xMod(r.c0, d) # r0 = a0 b0 - a1 b1 + r.c1.diff2xMod(r.c1, r.c0) + r.c1.diff2xMod(r.c1, D) + r.c0.diff2xMod(r.c0, D) # r0 = a0 b0 - a1 b1 func square2x_complex(r: var QuadraticExt2x, a: QuadraticExt) = ## Double-width unreduced complex squaring @@ -539,10 +558,10 @@ func square2x_complex(r: var QuadraticExt2x, a: QuadraticExt) = # Require 2 extra bits when QuadraticExt.has2extraBits(): - t0.double(a.c1) + t0.sumUnred(a.c1, a.c1) t1.sum(a.c0, a.c1) else: - t0.sumUnred(a.c1, a.c1) + t0.double(a.c1) t1.sum(a.c0, a.c1) r.c1.prod2x(t0, a.c0) # r1 = 2a0a1 @@ -556,26 +575,44 @@ func prod2x(r: var QuadraticExt2x, a, b: QuadraticExt) = else: {.error: "Not Implementated".} +func square2x_disjoint[Fdbl, F]( + r: var QuadraticExt2x[FDbl], + a0, a1: F) + +func square2x(r: var QuadraticExt2x, a: QuadraticExt) = + mixin fromComplexExtension, square2x_disjoint + when a.fromComplexExtension(): + r.square2x_complex(a) + else: + r.square2x_disjoint(a.c0, a.c1) + # Commutative ring implementation for generic quadratic extension fields # ---------------------------------------------------------------------- # # Some sparse functions, reconstruct a Fp4 from disjoint pieces # to limit copies, we provide versions with disjoint elements +# prod2x_disjoint: +# - 2 products in mul_sparse_by_line_xyz000 (Fp4) +# - 2 products in mul_sparse_by_line_xy000z (Fp4) +# - mul_by_line_xy0 in mul_sparse_by_line_xy00z0 (Fp6) +# +# square2x_disjoint: +# - cyclotomic square in Fp2 -> Fp6 -> Fp12 towering +# needs Fp4 as special case func prod2x_disjoint[Fdbl, F]( r: var QuadraticExt2x[FDbl], a: QuadraticExt[F], - b0, b1: F - ) = + b0, b1: F) = ## Return a * (b0, b1) in r static: doAssert Fdbl is doubleWidth(F) - var v0 {.noInit.}, v1 {.noInit.}: typeof(r.c0) # Double-width + var V0 {.noInit.}, V1 {.noInit.}: typeof(r.c0) # Double-width var t0 {.noInit.}, t1 {.noInit.}: typeof(a.c0) # Single-width # Require 2 extra bits - v0.prod2x(a.c0, b0) - v1.prod2x(a.c1, b1) + V0.prod2x(a.c0, b0) + V1.prod2x(a.c1, b1) when F.has2extraBits(): t0.sumUnred(b0, b1) t1.sumUnred(a.c0, a.c1) @@ -584,12 +621,35 @@ func prod2x_disjoint[Fdbl, F]( t1.sum(a.c0, a.c1) r.c1.prod2x(t0, t1) # r1 = (a0 + a1)(b0 + b1) , and at most 2 extra bits - r.c1.diff2xMod(r.c1, v0) # r1 = (a0 + a1)(b0 + b1) - a0b0 , and at most 1 extra bit - r.c1.diff2xMod(r.c1, v1) # r1 = (a0 + a1)(b0 + b1) - a0b0 - a1b1, and 0 extra bit + r.c1.diff2xMod(r.c1, V0) # r1 = (a0 + a1)(b0 + b1) - a0b0 , and at most 1 extra bit + r.c1.diff2xMod(r.c1, V1) # r1 = (a0 + a1)(b0 + b1) - a0b0 - a1b1, and 0 extra bit # TODO: This is correct modulo p, but we have some extra bits still here. 
- r.c0.prod2x(v1, NonResidue) # r0 = β a1 b1 - r.c0.sum2xUnred(r.c0, v0) # r0 = a0 b0 + β a1 b1 + r.c0.prod2x(V1, NonResidue) # r0 = β a1 b1 + r.c0.sum2xMod(r.c0, V0) # r0 = a0 b0 + β a1 b1 + +func square2x_disjoint[Fdbl, F]( + r: var QuadraticExt2x[FDbl], + a0, a1: F) = + ## Return (a0, a1)² in r + var V0 {.noInit.}, V1 {.noInit.}: typeof(r.c0) # Double-width + var t {.noInit.}: F # Single-width + + # TODO: which is the best formulation? 3 squarings or 2 Mul? + # It seems like the higher the tower the better squarings are + # So for Fp12 = 2xFp6, prefer squarings. + V0.square2x(a0) + V1.square2x(a1) + t.sum(a0, a1) + + # r0 = a0² + β a1² (option 1) <=> (a0 + a1)(a0 + β a1) - β a0a1 - a0a1 (option 2) + r.c0.prod2x(V1, NonResidue) + r.c0.sum2xMod(r.c0, V0) + + # r1 = 2 a0 a1 (option 1) = (a0 + a1)² - a0² - a1² (option 2) + r.c1.square2x(t) + r.c1.diff2xMod(r.c1, V0) + r.c1.diff2xMod(r.c1, V1) # ############################################################ # # @@ -921,7 +981,19 @@ func square*(r: var QuadraticExt, a: QuadraticExt) = r.c0.redc2x(d.c0) r.c1.redc2x(d.c1) else: - r.square_generic(a) + when true: # r.typeof.F.C in {BLS12_377, BW6_761}: + # BW6-761 requires too many registers for Dbl width path + r.square_generic(a) + else: + # TODO understand why Fp4[BLS12_377] + # is so slow in the branch + # TODO: + # - On Fp4, we can have a.c0.c0 off by p + # a reduction is missing + var d {.noInit.}: doubleWidth(typeof(r)) + d.square2x_disjoint(a.c0, a.c1) + r.c0.redc2x(d.c0) + r.c1.redc2x(d.c1) func prod*(r: var QuadraticExt, a, b: QuadraticExt) = mixin fromComplexExtension diff --git a/tests/t_fp4.nim b/tests/t_fp4.nim new file mode 100644 index 000000000..dba99a16f --- /dev/null +++ b/tests/t_fp4.nim @@ -0,0 +1,129 @@ +# Constantine +# Copyright (c) 2018-2019 Status Research & Development GmbH +# Copyright (c) 2020-Present Mamy André-Ratsimbazafy +# Licensed and distributed under either of +# * MIT license (license terms in the root directory or at http://opensource.org/licenses/MIT). +# * Apache v2 license (license terms in the root directory or at http://www.apache.org/licenses/LICENSE-2.0). +# at your option. This file may not be copied, modified, or distributed except according to those terms. 
+ +import + std/unittest, + # Internals + ../constantine/towers, + ../constantine/io/io_towers, + ../constantine/config/curves, + # Test utilities + ./t_fp_tower_template + +const TestCurves = [ + BN254_Nogami, + BN254_Snarks, + BLS12_377, + BLS12_381, + BW6_761 + ] + +runTowerTests( + ExtDegree = 4, + Iters = 12, + TestCurves = TestCurves, + moduleName = "test_fp4", + testSuiteDesc = "𝔽p4 = 𝔽p2[v]" +) + +# Fuzzing failure +# Issue when using Fp4Dbl + +suite "𝔽p4 - Anti-regression": + test "Partial reduction (off by p) on double-width field": + proc partred1() = + type F = Fp4[BN254_Snarks] + var x: F + x.fromHex( + "0x0000000000000000000fffffffffffffffffe000000fffffffffcffffff80000", + "0x000000000000007ffffffffff800000001fffe000000000007ffffffffffffe0", + "0x000000c0ff0300fcffffffff7f00000000f0ffffffffffffffff00000000e0ff", + "0x0e0a77c19a07df27e5eea36f7879462c0a7ceb28e5c70b3dd35d438dc58f4d9c" + ) + + # echo "x: ", x.toHex() + # echo "\n----------------------" + + var s: F + s.square(x) + + # echo "s: ", s.toHex() + # echo "\ns raw: ", s + + # echo "\n----------------------" + var p: F + p.prod(x, x) + + # echo "p: ", p.toHex() + # echo "\np raw: ", p + + check: bool(p == s) + + partred1() + + proc partred2() = + type F = Fp4[BN254_Snarks] + var x: F + x.fromHex( + "0x0660df54c75b67a0c32fc6208f08b13d8cc86cd93084180725a04884e7f45849", + "0x094185b0915ce1aa3bd3c63d33fd6d9cf3f04ea30fc88efe1e6e9b59117513bb", + "0x26c20beee711e46406372ab4f0e6d0069c67ded0a494bc0301bbfde48f7a4073", + "0x23c60254946def07120e46155466cc9b883b5c3d1c17d1d6516a6268a41dcc5d" + ) + + # echo "x: ", x.toHex() + # echo "\n----------------------" + + var s: F + s.square(x) + + # echo "s: ", s.toHex() + # echo "\ns raw: ", s + + # echo "\n----------------------" + var p: F + p.prod(x, x) + + # echo "p: ", p.toHex() + # echo "\np raw: ", p + + check: bool(p == s) + + partred2() + + + proc partred3() = + type F = Fp4[BN254_Snarks] + var x: F + x.fromHex( + "0x233066f735efcf7a0ad6e3ffa3afe4ed39bdfeffffb3f7d8b1fd7eeabfddfb36", + "0x1caba0b27fdfdfd512bdecf3fffbfebdb939fffffffbff8a14e663f7fef7fc85", + "0x212a64f0efefff1b7abe2ebe2bffbfc1b9335fb73ffd7c8815ffffffffffff8d", + "0x212ba4b1ff8feff552a61efff5ffffc5b839f7ffffffff71f477dffe7ffc7e08" + ) + + # echo "x: ", x.toHex() + # echo "\n----------------------" + + var s: F + s.square(x) + + # echo "s: ", s.toHex() + # echo "\ns raw: ", s + + # echo "\n----------------------" + var n, s2: F + n.neg(x) + s2.prod(n, n) + + # echo "s2: ", s2.toHex() + # echo "\ns2 raw: ", s2 + + check: bool(s == s2) + + partred3() diff --git a/tests/t_fp_tower_template.nim b/tests/t_fp_tower_template.nim index a9fa4e4e8..af8d8c5ea 100644 --- a/tests/t_fp_tower_template.nim +++ b/tests/t_fp_tower_template.nim @@ -20,6 +20,7 @@ import ../constantine/towers, ../constantine/config/[common, curves], ../constantine/arithmetic, + ../constantine/io/io_towers, # Test utilities ../helpers/[prng_unsafe, static_for] @@ -28,6 +29,8 @@ echo "\n------------------------------------------------------\n" template ExtField(degree: static int, curve: static Curve): untyped = when degree == 2: Fp2[curve] + elif degree == 4: + Fp4[curve] elif degree == 6: Fp6[curve] elif degree == 12: @@ -273,7 +276,7 @@ proc runTowerTests*[N]( rMul.prod(a, a) rSqr.square(a) - check: bool(rMul == rSqr) + doAssert bool(rMul == rSqr), "Failure with a: " & a.toHex() staticFor(curve, TestCurves): test(ExtField(ExtDegree, curve), Iters, gen = Uniform) @@ -292,7 +295,7 @@ proc runTowerTests*[N]( rSqr.square(a) rNegSqr.square(na) - check: bool(rSqr 
== rNegSqr) + doAssert bool(rSqr == rNegSqr), "Failure with a: " & a.toHex() staticFor(curve, TestCurves): test(ExtField(ExtDegree, curve), Iters, gen = Uniform) From b5e213beed9992c3fc08dc6bfaf43ca320302efe Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Mamy=20Andr=C3=A9-Ratsimbazafy?= Date: Mon, 8 Feb 2021 00:08:35 +0100 Subject: [PATCH 06/22] less copies and stack space in addition chains --- constantine/arithmetic/finite_fields.nim | 74 +++++++++---------- .../extension_fields.nim | 74 +++++++++---------- 2 files changed, 72 insertions(+), 76 deletions(-) diff --git a/constantine/arithmetic/finite_fields.nim b/constantine/arithmetic/finite_fields.nim index 044c9205d..67cabd136 100644 --- a/constantine/arithmetic/finite_fields.nim +++ b/constantine/arithmetic/finite_fields.nim @@ -389,59 +389,57 @@ func `*=`*(a: var FF, b: static int) = elif b == 2: a.double() elif b == 3: - let t1 = a - a.double() - a += t1 + var t {.noInit.}: typeof(a) + t.double(a) + a += t elif b == 4: a.double() a.double() elif b == 5: - let t1 = a - a.double() - a.double() - a += t1 + var t {.noInit.}: typeof(a) + t.double(a) + t.double() + a += t elif b == 6: - a.double() - let t2 = a - a.double() # 4 - a += t2 + var t {.noInit.}: typeof(a) + t.double(a) + t += a # 3 + a.double(t) elif b == 7: - let t1 = a - a.double() - let t2 = a - a.double() # 4 - a += t2 - a += t1 + var t {.noInit.}: typeof(a) + t.double(a) + t.double() + t.double() + a.diff(t, a) elif b == 8: a.double() a.double() a.double() elif b == 9: - let t1 = a - a.double() - a.double() - a.double() # 8 - a += t1 + var t {.noInit.}: typeof(a) + t.double(a) + t.double() + t.double() + a.sum(t, a) elif b == 10: + var t {.noInit.}: typeof(a) + t.double(a) + t.double() + a += t # 5 a.double() - let t2 = a - a.double() - a.double() # 8 - a += t2 elif b == 11: - let t1 = a - a.double() - let t2 = a - a.double() - a.double() # 8 - a += t2 - a += t1 + var t {.noInit.}: typeof(a) + t.double(a) + t += a # 3 + t.double() # 6 + t.double() # 12 + a.diff(t, a) # 11 elif b == 12: - a.double() - a.double() # 4 - let t4 = a - a.double() # 8 - a += t4 + var t {.noInit.}: typeof(a) + t.double(a) + t += a # 3 + t.double() # 6 + a.double(t) # 12 else: {.error: "Multiplication by this small int not implemented".} diff --git a/constantine/tower_field_extensions/extension_fields.nim b/constantine/tower_field_extensions/extension_fields.nim index a35229e4e..8596cda31 100644 --- a/constantine/tower_field_extensions/extension_fields.nim +++ b/constantine/tower_field_extensions/extension_fields.nim @@ -235,59 +235,57 @@ func `*=`*(a: var ExtensionField, b: static int) = elif b == 2: a.double() elif b == 3: - let t1 = a - a.double() - a += t1 + var t {.noInit.}: typeof(a) + t.double(a) + a += t elif b == 4: a.double() a.double() elif b == 5: - let t1 = a - a.double() - a.double() - a += t1 + var t {.noInit.}: typeof(a) + t.double(a) + t.double() + a += t elif b == 6: - a.double() - let t2 = a - a.double() # 4 - a += t2 + var t {.noInit.}: typeof(a) + t.double(a) + t += a # 3 + a.double(t) elif b == 7: - let t1 = a - a.double() - let t2 = a - a.double() # 4 - a += t2 - a += t1 + var t {.noInit.}: typeof(a) + t.double(a) + t.double() + t.double() + a.diff(t, a) elif b == 8: a.double() a.double() a.double() elif b == 9: - let t1 = a - a.double() - a.double() - a.double() # 8 - a += t1 + var t {.noInit.}: typeof(a) + t.double(a) + t.double() + t.double() + a.sum(t, a) elif b == 10: + var t {.noInit.}: typeof(a) + t.double(a) + t.double() + a += t # 5 a.double() - let t2 = a - a.double() - 
a.double() # 8 - a += t2 elif b == 11: - let t1 = a - a.double() - let t2 = a - a.double() - a.double() # 8 - a += t2 - a += t1 + var t {.noInit.}: typeof(a) + t.double(a) + t += a # 3 + t.double() # 6 + t.double() # 12 + a.diff(t, a) # 11 elif b == 12: - a.double() - a.double() # 4 - let t4 = a - a.double() # 8 - a += t4 + var t {.noInit.}: typeof(a) + t.double(a) + t += a # 3 + t.double() # 6 + a.double(t) # 12 else: {.error: "Multiplication by this small int not implemented".} From 3929d198d1ec2ecbd28d371897437d5b4093644e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Mamy=20Andr=C3=A9-Ratsimbazafy?= Date: Mon, 8 Feb 2021 00:29:22 +0100 Subject: [PATCH 07/22] Address https://github.com/mratsim/constantine/issues/154 partly --- .../tower_instantiation.nim | 29 ++++++++++++------- 1 file changed, 18 insertions(+), 11 deletions(-) diff --git a/constantine/tower_field_extensions/tower_instantiation.nim b/constantine/tower_field_extensions/tower_instantiation.nim index cdcbc1a71..bb337d1db 100644 --- a/constantine/tower_field_extensions/tower_instantiation.nim +++ b/constantine/tower_field_extensions/tower_instantiation.nim @@ -116,17 +116,24 @@ func prod*(r: var Fp2, a: Fp2, _: type NonResidue) {.inline.} = # BLS12_377 and BW6_761, use small addition chain r.mul_sparse_by_0y(a, v) else: - # BN254_Snarks, u = 9 - # Full 𝔽p2 multiplication is cheaper than addition chains - # for u*c0 and u*c1 - static: - doAssert u >= 0 and uint64(u) <= uint64(high(BaseType)) - doAssert v >= 0 and uint64(v) <= uint64(high(BaseType)) - # TODO: compile-time - var NR {.noInit.}: Fp2 - NR.c0.fromUint(uint u) - NR.c1.fromUint(uint v) - r.prod(a, NR) + # BN254_Snarks, u = 9, v = 1, β = -1 + # Even with u = 9, the 2x9 addition chains (8 additions total) + # are cheaper than full Fp2 multiplication + var t {.noInit.}: typeof(a.c0) + + t.prod(a.c0, u) + when v == 1 and Beta == -1: # Case BN254_Snarks + t -= a.c1 # r0 = u c0 + β v c1 + else: + {.error: "Unimplemented".} + + r.c1.prod(a.c1, u) + when v == 1: # r1 = v c0 + u c1 + r.c1 += a.c0 + # aliasing: a.c0 is unused + r.c0 = t + else: + {.error: "Unimplemented".} func `*=`*(a: var Fp2, _: type NonResidue) {.inline.} = ## Multiply an element of 𝔽p2 by the non-residue From f8f051aeb13b83651eae54a013b5ffac170ad7ce Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Mamy=20Andr=C3=A9-Ratsimbazafy?= Date: Mon, 8 Feb 2021 01:09:32 +0100 Subject: [PATCH 08/22] Fix #154, faster Fp4 square: less non-residue, no Mul, only square (bit more ops total) --- .../extension_fields.nim | 73 ++++++++++++++----- 1 file changed, 55 insertions(+), 18 deletions(-) diff --git a/constantine/tower_field_extensions/extension_fields.nim b/constantine/tower_field_extensions/extension_fields.nim index 8596cda31..dee15443f 100644 --- a/constantine/tower_field_extensions/extension_fields.nim +++ b/constantine/tower_field_extensions/extension_fields.nim @@ -807,31 +807,68 @@ func square_generic(r: var QuadraticExt, a: QuadraticExt) = # # Alternative 2: # c0² + β c1² <=> (c0 + c1)(c0 + β c1) - β c0c1 - c0c1 - mixin prod - var v0 {.noInit.}, v1 {.noInit.}: typeof(r.c0) + # + # This gives us 2 Mul and 2 mul-nonresidue (which is costly for BN254_Snarks) + # + # We can also reframe the 2nd term with only squarings + # which might be significantly faster on higher tower degrees + # + # 2 c0 c1 <=> (a0 + a1)² - a0² - a1² + # + # This gives us 3 Sqr and 1 Mul-non-residue + const costlyMul = block: + # No shortcutting in the VM :/ + when a.c0 is ExtensionField: + when a.c0.c0 is ExtensionField: + true + else: + false + 
else: + false - # v1 <- (c0 + β c1) - v1.prod(a.c1, NonResidue) - v1 += a.c0 + when QuadraticExt.C == BN254_Snarks or costlyMul: + var v0 {.noInit.}, v1 {.noInit.}: typeof(r.c0) + v0.square(a.c0) + v1.square(a.c1) - # v0 <- (c0 + c1)(c0 + β c1) - v0.sum(a.c0, a.c1) - v0 *= v1 + # Aliasing: a unneeded now + r.c1.diff(a.c0, a.c1) - # v1 <- c0 c1 - v1.prod(a.c0, a.c1) + # r0 = c0² + β c1² + r.c0.prod(v1, NonResidue) + r.c0 += v0 - # aliasing: a unneeded now + # r1 = (a0 + a1)² - a0² - a1² + r.c1.square() + r.c1 -= v0 + r.c1 -= v1 + + else: + mixin prod + var v0 {.noInit.}, v1 {.noInit.}: typeof(r.c0) - # r0 = (c0 + c1)(c0 + β c1) - c0c1 - v0 -= v1 + # v1 <- (c0 + β c1) + v1.prod(a.c1, NonResidue) + v1 += a.c0 - # r1 = 2 c0c1 - r.c1.double(v1) + # v0 <- (c0 + c1)(c0 + β c1) + v0.sum(a.c0, a.c1) + v0 *= v1 - # r0 = (c0 + c1)(c0 + β c1) - c0c1 - β c0c1 - v1 *= NonResidue - r.c0.diff(v0, v1) + # v1 <- c0 c1 + v1.prod(a.c0, a.c1) + + # aliasing: a unneeded now + + # r0 = (c0 + c1)(c0 + β c1) - c0c1 + v0 -= v1 + + # r1 = 2 c0c1 + r.c1.double(v1) + + # r0 = (c0 + c1)(c0 + β c1) - c0c1 - β c0c1 + v1 *= NonResidue + r.c0.diff(v0, v1) func prod_generic(r: var QuadraticExt, a, b: QuadraticExt) = ## Returns r = a * b From bb4a3793ce667a75069b33c02e2429fa47522adf Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Mamy=20Andr=C3=A9-Ratsimbazafy?= Date: Mon, 8 Feb 2021 01:16:50 +0100 Subject: [PATCH 09/22] Fix typo --- constantine/tower_field_extensions/extension_fields.nim | 4 ++-- tests/t_fp_tower_template.nim | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/constantine/tower_field_extensions/extension_fields.nim b/constantine/tower_field_extensions/extension_fields.nim index dee15443f..087fd1949 100644 --- a/constantine/tower_field_extensions/extension_fields.nim +++ b/constantine/tower_field_extensions/extension_fields.nim @@ -832,7 +832,7 @@ func square_generic(r: var QuadraticExt, a: QuadraticExt) = v1.square(a.c1) # Aliasing: a unneeded now - r.c1.diff(a.c0, a.c1) + r.c1.sum(a.c0, a.c1) # r0 = c0² + β c1² r.c0.prod(v1, NonResidue) @@ -842,7 +842,7 @@ func square_generic(r: var QuadraticExt, a: QuadraticExt) = r.c1.square() r.c1 -= v0 r.c1 -= v1 - + else: mixin prod var v0 {.noInit.}, v1 {.noInit.}: typeof(r.c0) diff --git a/tests/t_fp_tower_template.nim b/tests/t_fp_tower_template.nim index af8d8c5ea..1e8541c59 100644 --- a/tests/t_fp_tower_template.nim +++ b/tests/t_fp_tower_template.nim @@ -276,7 +276,7 @@ proc runTowerTests*[N]( rMul.prod(a, a) rSqr.square(a) - doAssert bool(rMul == rSqr), "Failure with a: " & a.toHex() + doAssert bool(rMul == rSqr), "Failure with a (" & $Field & "): " & a.toHex() staticFor(curve, TestCurves): test(ExtField(ExtDegree, curve), Iters, gen = Uniform) @@ -295,7 +295,7 @@ proc runTowerTests*[N]( rSqr.square(a) rNegSqr.square(na) - doAssert bool(rSqr == rNegSqr), "Failure with a: " & a.toHex() + doAssert bool(rSqr == rNegSqr), "Failure with a (" & $Field & "): " & a.toHex() staticFor(curve, TestCurves): test(ExtField(ExtDegree, curve), Iters, gen = Uniform) From 6b09da0610f7abd45dcaa830dfbf816de4f845d0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Mamy=20Andr=C3=A9-Ratsimbazafy?= Date: Mon, 8 Feb 2021 18:37:17 +0100 Subject: [PATCH 10/22] better assembly scheduling for add/sub --- .../arithmetic/assembly/limbs_asm_x86.nim | 38 ++++++++++--------- 1 file changed, 21 insertions(+), 17 deletions(-) diff --git a/constantine/arithmetic/assembly/limbs_asm_x86.nim b/constantine/arithmetic/assembly/limbs_asm_x86.nim index 8a73176c6..8321bcdfc 100644 --- 
a/constantine/arithmetic/assembly/limbs_asm_x86.nim +++ b/constantine/arithmetic/assembly/limbs_asm_x86.nim @@ -138,14 +138,16 @@ macro add_gen[N: static int](carry: var Carry, r: var Limbs[N], a, b: Limbs[N]): var `t0sym`{.noinit.}, `t1sym`{.noinit.}: BaseType # Algorithm - for i in 0 ..< N: - ctx.mov t0, arrA[i] - if i == 0: - ctx.add t0, arrB[0] - else: - ctx.adc t0, arrB[i] - ctx.mov arrR[i], t0 - swap(t0, t1) + ctx.mov t0, arrA[0] # Prologue + ctx.add t0, arrB[0] + + for i in 1 ..< N: + ctx.mov t1, arrA[i] # Prepare the next iteration + ctx.mov arrR[i-1], t0 # Save the previous reult in an interleaved manner + ctx.adc t1, arrB[i] # Compute + swap(t0, t1) # Break dependency chain + + ctx.mov arrR[N-1], t0 # Epilogue ctx.setToCarryFlag(carry) # Codegen @@ -197,16 +199,18 @@ macro sub_gen[N: static int](borrow: var Borrow, r: var Limbs[N], a, b: Limbs[N] var `t0sym`{.noinit.}, `t1sym`{.noinit.}: BaseType # Algorithm - for i in 0 ..< N: - ctx.mov t0, arrA[i] - if i == 0: - ctx.sub t0, arrB[0] - else: - ctx.sbb t0, arrB[i] - ctx.mov arrR[i], t0 - swap(t0, t1) - ctx.setToCarryFlag(borrow) + ctx.mov t0, arrA[0] # Prologue + ctx.sub t0, arrB[0] + for i in 1 ..< N: + ctx.mov t1, arrA[i] # Prepare the next iteration + ctx.mov arrR[i-1], t0 # Save the previous reult in an interleaved manner + ctx.sbb t1, arrB[i] # Compute + swap(t0, t1) # Break dependency chain + + ctx.mov arrR[N-1], t0 # Epilogue + ctx.setToCarryFlag(borrow) + # Codegen result.add ctx.generate From c0a5000320a6fd572167f8e8e409a0a38528a61d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Mamy=20Andr=C3=A9-Ratsimbazafy?= Date: Mon, 8 Feb 2021 18:50:01 +0100 Subject: [PATCH 11/22] Double-width -> Double-precision --- ...width.nim => bench_fp_double_precision.nim} | 12 ++++++------ constantine.nimble | 14 +++++++------- constantine/arithmetic.nim | 4 ++-- ....nim => limbs_asm_modular_dbl_prec_x86.nim} | 18 +++++++++++++----- ....nim => finite_fields_double_precision.nim} | 18 ++++++++++++------ .../extension_fields.nim | 18 +++++++++--------- ...im => t_finite_fields_double_precision.nim} | 2 +- 7 files changed, 50 insertions(+), 36 deletions(-) rename benchmarks/{bench_fp_double_width.nim => bench_fp_double_precision.nim} (96%) rename constantine/arithmetic/assembly/{limbs_asm_modular_dbl_width_x86.nim => limbs_asm_modular_dbl_prec_x86.nim} (83%) rename constantine/arithmetic/{finite_fields_double_width.nim => finite_fields_double_precision.nim} (87%) rename tests/{t_finite_fields_double_width.nim => t_finite_fields_double_precision.nim} (98%) diff --git a/benchmarks/bench_fp_double_width.nim b/benchmarks/bench_fp_double_precision.nim similarity index 96% rename from benchmarks/bench_fp_double_width.nim rename to benchmarks/bench_fp_double_precision.nim index 87ff4f37a..86d01c06c 100644 --- a/benchmarks/bench_fp_double_width.nim +++ b/benchmarks/bench_fp_double_precision.nim @@ -150,19 +150,19 @@ proc diff(T: typedesc, iters: int) = r.diff(a, b) proc diff2xNoReduce(T: typedesc, iters: int) = - var r, a, b: doubleWidth(T) + var r, a, b: doublePrec(T) rng.random_unsafe(r, T) rng.random_unsafe(a, T) rng.random_unsafe(b, T) - bench("Substraction 2x unreduced", $doubleWidth(T), iters): + bench("Substraction 2x unreduced", $doublePrec(T), iters): r.diff2xUnred(a, b) proc diff2x(T: typedesc, iters: int) = - var r, a, b: doubleWidth(T) + var r, a, b: doublePrec(T) rng.random_unsafe(r, T) rng.random_unsafe(a, T) rng.random_unsafe(b, T) - bench("Substraction 2x reduced", $doubleWidth(T), iters): + bench("Substraction 2x reduced", $doublePrec(T), 
iters): r.diff2xMod(a, b) proc prod2xBench*(rLen, aLen, bLen: static int, iters: int) = @@ -180,10 +180,10 @@ proc square2xBench*(rLen, aLen: static int, iters: int) = proc reduce2x*(T: typedesc, iters: int) = var r: T - var t: doubleWidth(T) + var t: doublePrec(T) rng.random_unsafe(t, T) - bench("Redc 2x", $T & " <- " & $doubleWidth(T), iters): + bench("Redc 2x", $T & " <- " & $doublePrec(T), iters): r.redc2x(t) proc main() = diff --git a/constantine.nimble b/constantine.nimble index bb3d0bb90..2d5223541 100644 --- a/constantine.nimble +++ b/constantine.nimble @@ -45,7 +45,7 @@ const testDesc: seq[tuple[path: string, useGMP: bool]] = @[ ("tests/t_fp_cubic_root.nim", false), # Double-width finite fields # ---------------------------------------------------------- - ("tests/t_finite_fields_double_width.nim", false), + ("tests/t_finite_fields_double_precision.nim", false), # Towers of extension fields # ---------------------------------------------------------- ("tests/t_fp2.nim", false), @@ -260,7 +260,7 @@ proc buildAllBenches() = echo "\n\n------------------------------------------------------\n" echo "Building benchmarks to ensure they stay relevant ..." buildBench("bench_fp") - buildBench("bench_fp_double_width") + buildBench("bench_fp_double_precision") buildBench("bench_fp2") buildBench("bench_fp6") buildBench("bench_fp12") @@ -401,19 +401,19 @@ task bench_fp_clang_noasm, "Run benchmark 𝔽p with clang - no Assembly": runBench("bench_fp", "clang", useAsm = false) task bench_fpdbl, "Run benchmark 𝔽pDbl with your default compiler": - runBench("bench_fp_double_width") + runBench("bench_fp_double_precision") task bench_fpdbl_gcc, "Run benchmark 𝔽p with gcc": - runBench("bench_fp_double_width", "gcc") + runBench("bench_fp_double_precision", "gcc") task bench_fpdbl_clang, "Run benchmark 𝔽p with clang": - runBench("bench_fp_double_width", "clang") + runBench("bench_fp_double_precision", "clang") task bench_fpdbl_gcc_noasm, "Run benchmark 𝔽p with gcc - no Assembly": - runBench("bench_fp_double_width", "gcc", useAsm = false) + runBench("bench_fp_double_precision", "gcc", useAsm = false) task bench_fpdbl_clang_noasm, "Run benchmark 𝔽p with clang - no Assembly": - runBench("bench_fp_double_width", "clang", useAsm = false) + runBench("bench_fp_double_precision", "clang", useAsm = false) task bench_fp2, "Run benchmark with 𝔽p2 your default compiler": runBench("bench_fp2") diff --git a/constantine/arithmetic.nim b/constantine/arithmetic.nim index 7e5339a35..540cc6e20 100644 --- a/constantine/arithmetic.nim +++ b/constantine/arithmetic.nim @@ -12,7 +12,7 @@ import finite_fields, finite_fields_inversion, finite_fields_square_root, - finite_fields_double_width + finite_fields_double_precision ] export @@ -21,4 +21,4 @@ export finite_fields, finite_fields_inversion, finite_fields_square_root, - finite_fields_double_width + finite_fields_double_precision diff --git a/constantine/arithmetic/assembly/limbs_asm_modular_dbl_width_x86.nim b/constantine/arithmetic/assembly/limbs_asm_modular_dbl_prec_x86.nim similarity index 83% rename from constantine/arithmetic/assembly/limbs_asm_modular_dbl_width_x86.nim rename to constantine/arithmetic/assembly/limbs_asm_modular_dbl_prec_x86.nim index 54f6b4c84..f53db7d93 100644 --- a/constantine/arithmetic/assembly/limbs_asm_modular_dbl_width_x86.nim +++ b/constantine/arithmetic/assembly/limbs_asm_modular_dbl_prec_x86.nim @@ -14,17 +14,25 @@ import ../../primitives # ############################################################ -# -# Assembly implementation of FpDbl -# +# # 
+#                                                              #
+#        Assembly implementation of FpDbl                      #
+#                                                              #
 # ############################################################

+# An FpDbl is a partially-reduced double-precision element of Fp
+# The allowed range is [0, 2ⁿp)
+# with n = w*WordBitSize
+# and w the number of words necessary to represent p on the machine.
+# Concretely a 381-bit p needs 6*64 bits limbs (hence 384 bits total)
+# and so FpDbl would be 768 bits.
+
 static: doAssert UseASM_X86_64

 {.localPassC:"-fomit-frame-pointer".} # Needed so that the compiler finds enough registers

-# TODO slower than intrinsics
+# Field addition
+# ------------------------------------------------------------

-# Substraction
+# Field Substraction
 # ------------------------------------------------------------

 macro sub2x_gen[N: static int](R: var Limbs[N], A, B: Limbs[N], m: Limbs[N div 2]): untyped =
diff --git a/constantine/arithmetic/finite_fields_double_width.nim b/constantine/arithmetic/finite_fields_double_precision.nim
similarity index 87%
rename from constantine/arithmetic/finite_fields_double_width.nim
rename to constantine/arithmetic/finite_fields_double_precision.nim
index a9a94dbf5..602a07c26 100644
--- a/constantine/arithmetic/finite_fields_double_width.nim
+++ b/constantine/arithmetic/finite_fields_double_precision.nim
@@ -16,16 +16,22 @@ import
   ./limbs_montgomery

 when UseASM_X86_64:
-  import assembly/limbs_asm_modular_dbl_width_x86
+  import assembly/limbs_asm_modular_dbl_prec_x86

 type FpDbl*[C: static Curve] = object
-    ## Double-width Fp element
-    ## This allows saving on reductions
-    # We directly work with double the number of limbs
+    ## Double-precision Fp element
+    ## An FpDbl is a partially-reduced double-precision element of Fp
+    ## The allowed range is [0, 2ⁿp)
+    ## with n = w*WordBitSize
+    ## and w the number of words necessary to represent p on the machine.
+    ## Concretely a 381-bit p needs 6*64 bits limbs (hence 384 bits total)
+    ## and so FpDbl would be 768 bits.
+    # We directly work with double the number of limbs,
+    # instead of BigInt indirection.
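    # A minimal usage sketch of the lazy-reduction pattern this type enables
    # (illustrative only; the calls mirror how the double-precision tests
    # exercise FpDbl):
    #   var d0 {.noInit.}, d1 {.noInit.}: FpDbl[C]
    #   d0.prod2x(a, b)       # Fp * Fp -> full double-precision product
    #   d1.prod2x(c, d)
    #   d0.sum2xMod(d0, d1)   # lazy addition, output stays in [0, 2ⁿp)
    #   r.redc2x(d0)          # a single Montgomery reduction at the end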
limbs2x*: matchingLimbs2x(C) -template doubleWidth*(T: type Fp): type = - ## Return the double-width type matching with Fp +template doublePrec*(T: type Fp): type = + ## Return the double-precision type matching with Fp FpDbl[T.C] # No exceptions allowed diff --git a/constantine/tower_field_extensions/extension_fields.nim b/constantine/tower_field_extensions/extension_fields.nim index 087fd1949..ab702f47e 100644 --- a/constantine/tower_field_extensions/extension_fields.nim +++ b/constantine/tower_field_extensions/extension_fields.nim @@ -319,17 +319,17 @@ type ExtensionField2x[F] = QuadraticExt2x[F] or CubicExt2x[F] -template doubleWidth(T: type ExtensionField): type = +template doublePrec(T: type ExtensionField): type = # For now naive unrolling, recursive template don't match # and I don't want to deal with types in macros when T is QuadraticExt: when T.F is QuadraticExt: # Fp4Dbl - QuadraticExt2x[QuadraticExt2x[doubleWidth(T.F.F)]] + QuadraticExt2x[QuadraticExt2x[doublePrec(T.F.F)]] elif T.F is Fp: # Fp2Dbl - QuadraticExt2x[doubleWidth(T.F)] + QuadraticExt2x[doublePrec(T.F)] elif T is CubicExt: when T.F is QuadraticExt: # Fp6Dbl - CubicExt2x[QuadraticExt2x[doubleWidth(T.F.F)]] + CubicExt2x[QuadraticExt2x[doublePrec(T.F.F)]] func has1extraBit(E: type ExtensionField): bool = ## We construct extensions only on Fp (and not Fr) @@ -603,7 +603,7 @@ func prod2x_disjoint[Fdbl, F]( a: QuadraticExt[F], b0, b1: F) = ## Return a * (b0, b1) in r - static: doAssert Fdbl is doubleWidth(F) + static: doAssert Fdbl is doublePrec(F) var V0 {.noInit.}, V1 {.noInit.}: typeof(r.c0) # Double-width var t0 {.noInit.}, t1 {.noInit.}: typeof(a.c0) # Single-width @@ -1011,7 +1011,7 @@ func square*(r: var QuadraticExt, a: QuadraticExt) = when true: r.square_complex(a) else: # slower - var d {.noInit.}: doubleWidth(typeof(r)) + var d {.noInit.}: doublePrec(typeof(r)) d.square2x_complex(a) r.c0.redc2x(d.c0) r.c1.redc2x(d.c1) @@ -1025,7 +1025,7 @@ func square*(r: var QuadraticExt, a: QuadraticExt) = # TODO: # - On Fp4, we can have a.c0.c0 off by p # a reduction is missing - var d {.noInit.}: doubleWidth(typeof(r)) + var d {.noInit.}: doublePrec(typeof(r)) d.square2x_disjoint(a.c0, a.c1) r.c0.redc2x(d.c0) r.c1.redc2x(d.c1) @@ -1036,7 +1036,7 @@ func prod*(r: var QuadraticExt, a, b: QuadraticExt) = when false: r.prod_complex(a, b) else: # faster - var d {.noInit.}: doubleWidth(typeof(r)) + var d {.noInit.}: doublePrec(typeof(r)) d.prod2x_complex(a, b) r.c0.redc2x(d.c0) r.c1.redc2x(d.c1) @@ -1044,7 +1044,7 @@ func prod*(r: var QuadraticExt, a, b: QuadraticExt) = when true: # typeof(r.c0) is Fp: r.prod_generic(a, b) else: # Deactivated r.c0 is correct modulo p but >= p - var d {.noInit.}: doubleWidth(typeof(r)) + var d {.noInit.}: doublePrec(typeof(r)) d.prod2x_disjoint(a, b.c0, b.c1) r.c0.redc2x(d.c0) r.c1.redc2x(d.c1) diff --git a/tests/t_finite_fields_double_width.nim b/tests/t_finite_fields_double_precision.nim similarity index 98% rename from tests/t_finite_fields_double_width.nim rename to tests/t_finite_fields_double_precision.nim index 8435022b5..11002b7f6 100644 --- a/tests/t_finite_fields_double_width.nim +++ b/tests/t_finite_fields_double_precision.nim @@ -22,7 +22,7 @@ var rng: RngState let seed = uint32(getTime().toUnix() and (1'i64 shl 32 - 1)) # unixTime mod 2^32 rng.seed(seed) echo "\n------------------------------------------------------\n" -echo "test_finite_fields_double_width xoshiro512** seed: ", seed +echo "test_finite_fields_double_precision xoshiro512** seed: ", seed template mulTest(rng_gen: 
untyped): untyped = proc `mul _ rng_gen`(C: static Curve) = From 670294135e6441b26c808bdf5535699a438d2da0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Mamy=20Andr=C3=A9-Ratsimbazafy?= Date: Mon, 8 Feb 2021 18:52:49 +0100 Subject: [PATCH 12/22] Unred -> Unr --- benchmarks/bench_fp_double_precision.nim | 14 +++++----- constantine/arithmetic/finite_fields.nim | 4 +-- .../finite_fields_double_precision.nim | 4 +-- .../extension_fields.nim | 26 +++++++++---------- 4 files changed, 24 insertions(+), 24 deletions(-) diff --git a/benchmarks/bench_fp_double_precision.nim b/benchmarks/bench_fp_double_precision.nim index 86d01c06c..d62c64854 100644 --- a/benchmarks/bench_fp_double_precision.nim +++ b/benchmarks/bench_fp_double_precision.nim @@ -121,12 +121,12 @@ func random_unsafe(rng: var RngState, a: var FpDbl, Base: typedesc) = for i in 0 ..< aHi.mres.limbs.len: a.limbs2x[aLo.mres.limbs.len+i] = aHi.mres.limbs[i] -proc sumUnred(T: typedesc, iters: int) = +proc sumUnr(T: typedesc, iters: int) = var r: T let a = rng.random_unsafe(T) let b = rng.random_unsafe(T) bench("Addition unreduced", $T, iters): - r.sumUnred(a, b) + r.sumUnr(a, b) proc sum(T: typedesc, iters: int) = var r: T @@ -135,12 +135,12 @@ proc sum(T: typedesc, iters: int) = bench("Addition", $T, iters): r.sum(a, b) -proc diffUnred(T: typedesc, iters: int) = +proc diffUnr(T: typedesc, iters: int) = var r: T let a = rng.random_unsafe(T) let b = rng.random_unsafe(T) bench("Substraction unreduced", $T, iters): - r.diffUnred(a, b) + r.diffUnr(a, b) proc diff(T: typedesc, iters: int) = var r: T @@ -155,7 +155,7 @@ proc diff2xNoReduce(T: typedesc, iters: int) = rng.random_unsafe(a, T) rng.random_unsafe(b, T) bench("Substraction 2x unreduced", $doublePrec(T), iters): - r.diff2xUnred(a, b) + r.diff2xUnr(a, b) proc diff2x(T: typedesc, iters: int) = var r, a, b: doublePrec(T) @@ -188,8 +188,8 @@ proc reduce2x*(T: typedesc, iters: int) = proc main() = separator() - sumUnred(Fp[BLS12_381], iters = 10_000_000) - diffUnred(Fp[BLS12_381], iters = 10_000_000) + sumUnr(Fp[BLS12_381], iters = 10_000_000) + diffUnr(Fp[BLS12_381], iters = 10_000_000) sum(Fp[BLS12_381], iters = 10_000_000) diff(Fp[BLS12_381], iters = 10_000_000) diff2x(Fp[BLS12_381], iters = 10_000_000) diff --git a/constantine/arithmetic/finite_fields.nim b/constantine/arithmetic/finite_fields.nim index 67cabd136..e89cc3c45 100644 --- a/constantine/arithmetic/finite_fields.nim +++ b/constantine/arithmetic/finite_fields.nim @@ -169,7 +169,7 @@ func sum*(r: var FF, a, b: FF) {.meter.} = overflowed = overflowed or not(r.mres < FF.fieldMod()) discard csub(r.mres, FF.fieldMod(), overflowed) -func sumUnred*(r: var FF, a, b: FF) {.meter.} = +func sumUnr*(r: var FF, a, b: FF) {.meter.} = ## Sum ``a`` and ``b`` into ``r`` without reduction discard r.mres.sum(a.mres, b.mres) @@ -183,7 +183,7 @@ func diff*(r: var FF, a, b: FF) {.meter.} = var underflowed = r.mres.diff(a.mres, b.mres) discard cadd(r.mres, FF.fieldMod(), underflowed) -func diffUnred*(r: var FF, a, b: FF) {.meter.} = +func diffUnr*(r: var FF, a, b: FF) {.meter.} = ## Substract `b` from `a` and store the result into `r` ## without reduction discard r.mres.diff(a.mres, b.mres) diff --git a/constantine/arithmetic/finite_fields_double_precision.nim b/constantine/arithmetic/finite_fields_double_precision.nim index 602a07c26..e9f41fa9f 100644 --- a/constantine/arithmetic/finite_fields_double_precision.nim +++ b/constantine/arithmetic/finite_fields_double_precision.nim @@ -65,7 +65,7 @@ func redc2x*(r: var Fp, a: FpDbl) = Fp.canUseNoCarryMontyMul() 
) -func diff2xUnred*(r: var FpDbl, a, b: FpDbl) = +func diff2xUnr*(r: var FpDbl, a, b: FpDbl) = ## Double-width substraction without reduction discard r.limbs2x.diff(a.limbs2x, b.limbs2x) @@ -84,7 +84,7 @@ func diff2xMod*(r: var FpDbl, a, b: FpDbl) = addC(carry, sum, r.limbs2x[i+N], M.limbs[i], carry) underflowed.ccopy(r.limbs2x[i+N], sum) -func sum2xUnred*(r: var FpDbl, a, b: FpDbl) = +func sum2xUnr*(r: var FpDbl, a, b: FpDbl) = ## Double-width addition without reduction discard r.limbs2x.sum(a.limbs2x, b.limbs2x) diff --git a/constantine/tower_field_extensions/extension_fields.nim b/constantine/tower_field_extensions/extension_fields.nim index ab702f47e..c6cbec224 100644 --- a/constantine/tower_field_extensions/extension_fields.nim +++ b/constantine/tower_field_extensions/extension_fields.nim @@ -367,25 +367,25 @@ func setZero*(a: var ExtensionField2x) = # Abelian group # ------------------------------------------------------------------- -func sumUnred(r: var ExtensionField, a, b: ExtensionField) = +func sumUnr(r: var ExtensionField, a, b: ExtensionField) = ## Sum ``a`` and ``b`` into ``r`` staticFor i, 0, a.coords.len: - r.coords[i].sumUnred(a.coords[i], b.coords[i]) + r.coords[i].sumUnr(a.coords[i], b.coords[i]) -func diff2xUnred(r: var ExtensionField2x, a, b: ExtensionField2x) = +func diff2xUnr(r: var ExtensionField2x, a, b: ExtensionField2x) = ## Double-width substraction without reduction staticFor i, 0, a.coords.len: - r.coords[i].diff2xUnred(a.coords[i], b.coords[i]) + r.coords[i].diff2xUnr(a.coords[i], b.coords[i]) func diff2xMod(r: var ExtensionField2x, a, b: ExtensionField2x) = ## Double-width modular substraction staticFor i, 0, a.coords.len: r.coords[i].diff2xMod(a.coords[i], b.coords[i]) -func sum2xUnred(r: var ExtensionField2x, a, b: ExtensionField2x) = +func sum2xUnr(r: var ExtensionField2x, a, b: ExtensionField2x) = ## Double-width addition without reduction staticFor i, 0, a.coords.len: - r.coords[i].sum2xUnred(a.coords[i], b.coords[i]) + r.coords[i].sum2xUnr(a.coords[i], b.coords[i]) func sum2xMod(r: var ExtensionField2x, a, b: ExtensionField2x) = ## Double-width modular addition @@ -532,15 +532,15 @@ func prod2x_complex(r: var QuadraticExt2x, a, b: QuadraticExt) = r.c0.prod2x(a.c0, b.c0) # r0 = a0 b0 D.prod2x(a.c1, b.c1) # d = a1 b1 when QuadraticExt.has1extraBit(): - t0.sumUnred(a.c0, a.c1) - t1.sumUnred(b.c0, b.c1) + t0.sumUnr(a.c0, a.c1) + t1.sumUnr(b.c0, b.c1) else: t0.sum(a.c0, a.c1) t1.sum(b.c0, b.c1) r.c1.prod2x(t0, t1) # r1 = (b0 + b1)(a0 + a1) when QuadraticExt.has1extraBit(): - r.c1.diff2xUnred(r.c1, r.c0) # r1 = (b0 + b1)(a0 + a1) - a0 b0 - r.c1.diff2xUnred(r.c1, D) # r1 = (b0 + b1)(a0 + a1) - a0 b0 - a1b1 + r.c1.diff2xUnr(r.c1, r.c0) # r1 = (b0 + b1)(a0 + a1) - a0 b0 + r.c1.diff2xUnr(r.c1, D) # r1 = (b0 + b1)(a0 + a1) - a0 b0 - a1b1 else: r.c1.diff2xMod(r.c1, r.c0) r.c1.diff2xMod(r.c1, D) @@ -556,7 +556,7 @@ func square2x_complex(r: var QuadraticExt2x, a: QuadraticExt) = # Require 2 extra bits when QuadraticExt.has2extraBits(): - t0.sumUnred(a.c1, a.c1) + t0.sumUnr(a.c1, a.c1) t1.sum(a.c0, a.c1) else: t0.double(a.c1) @@ -612,8 +612,8 @@ func prod2x_disjoint[Fdbl, F]( V0.prod2x(a.c0, b0) V1.prod2x(a.c1, b1) when F.has2extraBits(): - t0.sumUnred(b0, b1) - t1.sumUnred(a.c0, a.c1) + t0.sumUnr(b0, b1) + t1.sumUnr(a.c0, a.c1) else: t0.sum(b0, b1) t1.sum(a.c0, a.c1) From 80434b23006b267e55f0ea01615b2cfbc53247aa Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Mamy=20Andr=C3=A9-Ratsimbazafy?= Date: Mon, 8 Feb 2021 19:37:23 +0100 Subject: [PATCH 13/22] double-precision 
modular addition --- benchmarks/bench_fp_double_precision.nim | 28 +++++- constantine.nimble | 2 +- .../limbs_asm_modular_dbl_prec_x86.nim | 86 +++++++++++++++++-- .../assembly/limbs_asm_modular_x86.nim | 2 +- .../finite_fields_double_precision.nim | 37 ++++++-- constantine/arithmetic/limbs_montgomery.nim | 8 +- .../extension_fields.nim | 16 ++-- docs/optimizations.md | 4 +- tests/t_finite_fields_double_precision.nim | 4 +- tests/t_fp4.nim | 2 +- 10 files changed, 154 insertions(+), 35 deletions(-) diff --git a/benchmarks/bench_fp_double_precision.nim b/benchmarks/bench_fp_double_precision.nim index d62c64854..db4a4f25d 100644 --- a/benchmarks/bench_fp_double_precision.nim +++ b/benchmarks/bench_fp_double_precision.nim @@ -149,7 +149,23 @@ proc diff(T: typedesc, iters: int) = bench("Substraction", $T, iters): r.diff(a, b) -proc diff2xNoReduce(T: typedesc, iters: int) = +proc sum2xUnreduce(T: typedesc, iters: int) = + var r, a, b: doublePrec(T) + rng.random_unsafe(r, T) + rng.random_unsafe(a, T) + rng.random_unsafe(b, T) + bench("Addition 2x unreduced", $doublePrec(T), iters): + r.sum2xUnr(a, b) + +proc sum2x(T: typedesc, iters: int) = + var r, a, b: doublePrec(T) + rng.random_unsafe(r, T) + rng.random_unsafe(a, T) + rng.random_unsafe(b, T) + bench("Addition 2x reduced", $doublePrec(T), iters): + r.sum2xMod(a, b) + +proc diff2xUnreduce(T: typedesc, iters: int) = var r, a, b: doublePrec(T) rng.random_unsafe(r, T) rng.random_unsafe(a, T) @@ -188,12 +204,16 @@ proc reduce2x*(T: typedesc, iters: int) = proc main() = separator() - sumUnr(Fp[BLS12_381], iters = 10_000_000) - diffUnr(Fp[BLS12_381], iters = 10_000_000) sum(Fp[BLS12_381], iters = 10_000_000) + sumUnr(Fp[BLS12_381], iters = 10_000_000) diff(Fp[BLS12_381], iters = 10_000_000) + diffUnr(Fp[BLS12_381], iters = 10_000_000) + separator() + sum2x(Fp[BLS12_381], iters = 10_000_000) + sum2xUnreduce(Fp[BLS12_381], iters = 10_000_000) diff2x(Fp[BLS12_381], iters = 10_000_000) - diff2xNoReduce(Fp[BLS12_381], iters = 10_000_000) + diff2xUnreduce(Fp[BLS12_381], iters = 10_000_000) + separator() prod2xBench(768, 384, 384, iters = 10_000_000) square2xBench(768, 384, iters = 10_000_000) reduce2x(Fp[BLS12_381], iters = 10_000_000) diff --git a/constantine.nimble b/constantine.nimble index 2d5223541..687a17e93 100644 --- a/constantine.nimble +++ b/constantine.nimble @@ -43,7 +43,7 @@ const testDesc: seq[tuple[path: string, useGMP: bool]] = @[ ("tests/t_finite_fields_powinv.nim", false), ("tests/t_finite_fields_vs_gmp.nim", true), ("tests/t_fp_cubic_root.nim", false), - # Double-width finite fields + # Double-precision finite fields # ---------------------------------------------------------- ("tests/t_finite_fields_double_precision.nim", false), # Towers of extension fields diff --git a/constantine/arithmetic/assembly/limbs_asm_modular_dbl_prec_x86.nim b/constantine/arithmetic/assembly/limbs_asm_modular_dbl_prec_x86.nim index f53db7d93..5edbf23a2 100644 --- a/constantine/arithmetic/assembly/limbs_asm_modular_dbl_prec_x86.nim +++ b/constantine/arithmetic/assembly/limbs_asm_modular_dbl_prec_x86.nim @@ -32,11 +32,85 @@ static: doAssert UseASM_X86_64 # Field addition # ------------------------------------------------------------ + +macro addmod2x_gen[N: static int](R: var Limbs[N], A, B: Limbs[N], m: Limbs[N div 2]): untyped = + ## Generate an optimized out-of-place double-precision addition kernel + + result = newStmtList() + + var ctx = init(Assembler_x86, BaseType) + let + H = N div 2 + + r = init(OperandArray, nimSymbol = R, N, PointerInReg, 
InputOutput) + # We reuse the reg used for b for overflow detection + b = init(OperandArray, nimSymbol = B, N, PointerInReg, InputOutput) + # We could force m as immediate by specializing per moduli + M = init(OperandArray, nimSymbol = m, N, PointerInReg, Input) + # If N is too big, we need to spill registers. TODO. + u = init(OperandArray, nimSymbol = ident"U", H, ElemsInReg, InputOutput) + v = init(OperandArray, nimSymbol = ident"V", H, ElemsInReg, InputOutput) + + let usym = u.nimSymbol + let vsym = v.nimSymbol + result.add quote do: + var `usym`{.noinit.}, `vsym` {.noInit.}: typeof(`A`) + staticFor i, 0, `H`: + `usym`[i] = `A`[i] + staticFor i, `H`, `N`: + `vsym`[i-`H`] = `A`[i] + + # Addition + # u = a[0.. Date: Mon, 8 Feb 2021 21:01:35 +0100 Subject: [PATCH 14/22] Replace canUseNoCarryMontyMul and canUseNoCarryMontySquare by getSpareBits --- .../limbs_asm_modular_dbl_prec_x86.nim | 2 +- .../assembly/limbs_asm_modular_x86.nim | 2 +- .../assembly/limbs_asm_montred_x86.nim | 8 +-- .../limbs_asm_montred_x86_adx_bmi2.nim | 8 +-- constantine/arithmetic/bigints_montgomery.nim | 32 +++++------ constantine/arithmetic/finite_fields.nim | 22 ++++---- .../finite_fields_double_precision.nim | 2 +- constantine/arithmetic/limbs_montgomery.nim | 54 +++++++++---------- constantine/config/curves_derived.nim | 14 ++--- constantine/config/curves_prop_derived.nim | 21 ++++---- constantine/config/precompute.nim | 29 +++++----- .../extension_fields.nim | 4 +- helpers/prng_unsafe.nim | 6 +-- tests/t_finite_fields_mulsquare.nim | 11 ++-- tests/t_fr.nim | 6 +-- 15 files changed, 105 insertions(+), 116 deletions(-) diff --git a/constantine/arithmetic/assembly/limbs_asm_modular_dbl_prec_x86.nim b/constantine/arithmetic/assembly/limbs_asm_modular_dbl_prec_x86.nim index 5edbf23a2..322e219b2 100644 --- a/constantine/arithmetic/assembly/limbs_asm_modular_dbl_prec_x86.nim +++ b/constantine/arithmetic/assembly/limbs_asm_modular_dbl_prec_x86.nim @@ -77,7 +77,7 @@ macro addmod2x_gen[N: static int](R: var Limbs[N], A, B: Limbs[N], m: Limbs[N di ctx.mov u[i-H], v[i-H] # Mask: overflowed contains 0xFFFF or 0x0000 - # TODO: unnecessary if MSB never set, i.e. "canUseNoCarryMontyMul" + # TODO: unnecessary if MSB never set, i.e. "Field.getSpareBits >= 1" let overflowed = b.reuseRegister() ctx.sbb overflowed, overflowed diff --git a/constantine/arithmetic/assembly/limbs_asm_modular_x86.nim b/constantine/arithmetic/assembly/limbs_asm_modular_x86.nim index b5a9b2d87..0d2883ad9 100644 --- a/constantine/arithmetic/assembly/limbs_asm_modular_x86.nim +++ b/constantine/arithmetic/assembly/limbs_asm_modular_x86.nim @@ -69,7 +69,7 @@ macro addmod_gen[N: static int](R: var Limbs[N], A, B, m: Limbs[N]): untyped = ctx.mov v[i], u[i] # Mask: overflowed contains 0xFFFF or 0x0000 - # TODO: unnecessary if MSB never set, i.e. "canUseNoCarryMontyMul" + # TODO: unnecessary if MSB never set, i.e. 
"Field.getSpareBits >= 1" let overflowed = b.reuseRegister() ctx.sbb overflowed, overflowed diff --git a/constantine/arithmetic/assembly/limbs_asm_montred_x86.nim b/constantine/arithmetic/assembly/limbs_asm_montred_x86.nim index 5f119f756..8e48de73e 100644 --- a/constantine/arithmetic/assembly/limbs_asm_montred_x86.nim +++ b/constantine/arithmetic/assembly/limbs_asm_montred_x86.nim @@ -88,7 +88,7 @@ macro montyRedc2x_gen[N: static int]( a_MR: array[N*2, SecretWord], M_MR: array[N, SecretWord], m0ninv_MR: BaseType, - canUseNoCarryMontyMul: static bool + spareBits: static int ) = # TODO, slower than Clang, in particular due to the shadowing @@ -236,7 +236,7 @@ macro montyRedc2x_gen[N: static int]( let reuse = repackRegisters(t, scratch[N], scratch[N+1]) - if canUseNoCarryMontyMul: + if spareBits >= 1: ctx.finalSubNoCarry(r, scratch, M, reuse) else: ctx.finalSubCanOverflow(r, scratch, M, reuse, rRAX) @@ -249,7 +249,7 @@ func montRed_asm*[N: static int]( a: array[N*2, SecretWord], M: array[N, SecretWord], m0ninv: BaseType, - canUseNoCarryMontyMul: static bool + spareBits: static int ) = ## Constant-time Montgomery reduction - montyRedc2x_gen(r, a, M, m0ninv, canUseNoCarryMontyMul) + montyRedc2x_gen(r, a, M, m0ninv, spareBits) diff --git a/constantine/arithmetic/assembly/limbs_asm_montred_x86_adx_bmi2.nim b/constantine/arithmetic/assembly/limbs_asm_montred_x86_adx_bmi2.nim index 746cefc05..c4e652019 100644 --- a/constantine/arithmetic/assembly/limbs_asm_montred_x86_adx_bmi2.nim +++ b/constantine/arithmetic/assembly/limbs_asm_montred_x86_adx_bmi2.nim @@ -40,7 +40,7 @@ macro montyRedc2xx_gen[N: static int]( a_MR: array[N*2, SecretWord], M_MR: array[N, SecretWord], m0ninv_MR: BaseType, - canUseNoCarryMontyMul: static bool + spareBits: static int ) = # TODO, slower than Clang, in particular due to the shadowing @@ -175,7 +175,7 @@ macro montyRedc2xx_gen[N: static int]( let reuse = repackRegisters(t, scratch[N]) - if canUseNoCarryMontyMul: + if spareBits >= 1: ctx.finalSubNoCarry(r, scratch, M, reuse) else: ctx.finalSubCanOverflow(r, scratch, M, reuse, hi) @@ -188,7 +188,7 @@ func montRed_asm_adx_bmi2*[N: static int]( a: array[N*2, SecretWord], M: array[N, SecretWord], m0ninv: BaseType, - canUseNoCarryMontyMul: static bool + spareBits: static int ) = ## Constant-time Montgomery reduction - montyRedc2xx_gen(r, a, M, m0ninv, canUseNoCarryMontyMul) + montyRedc2xx_gen(r, a, M, m0ninv, spareBits) diff --git a/constantine/arithmetic/bigints_montgomery.nim b/constantine/arithmetic/bigints_montgomery.nim index 489d610b1..b28e9a630 100644 --- a/constantine/arithmetic/bigints_montgomery.nim +++ b/constantine/arithmetic/bigints_montgomery.nim @@ -25,7 +25,7 @@ import # # ############################################################ -func montyResidue*(mres: var BigInt, a, N, r2modM: BigInt, m0ninv: static BaseType, canUseNoCarryMontyMul: static bool) = +func montyResidue*(mres: var BigInt, a, N, r2modM: BigInt, m0ninv: static BaseType, spareBits: static int) = ## Convert a BigInt from its natural representation ## to the Montgomery n-residue form ## @@ -40,9 +40,9 @@ func montyResidue*(mres: var BigInt, a, N, r2modM: BigInt, m0ninv: static BaseTy ## - `r2modM` is R² (mod M) ## with W = M.len ## and R = (2^WordBitWidth)^W - montyResidue(mres.limbs, a.limbs, N.limbs, r2modM.limbs, m0ninv, canUseNoCarryMontyMul) + montyResidue(mres.limbs, a.limbs, N.limbs, r2modM.limbs, m0ninv, spareBits) -func redc*[mBits](r: var BigInt[mBits], a, M: BigInt[mBits], m0ninv: static BaseType, canUseNoCarryMontyMul: static bool) = 
+func redc*[mBits](r: var BigInt[mBits], a, M: BigInt[mBits], m0ninv: static BaseType, spareBits: static int) = ## Convert a BigInt from its Montgomery n-residue form ## to the natural representation ## @@ -54,26 +54,26 @@ func redc*[mBits](r: var BigInt[mBits], a, M: BigInt[mBits], m0ninv: static Base var one {.noInit.}: BigInt[mBits] one.setOne() one - redc(r.limbs, a.limbs, one.limbs, M.limbs, m0ninv, canUseNoCarryMontyMul) + redc(r.limbs, a.limbs, one.limbs, M.limbs, m0ninv, spareBits) -func montyMul*(r: var BigInt, a, b, M: BigInt, negInvModWord: static BaseType, canUseNoCarryMontyMul: static bool) = +func montyMul*(r: var BigInt, a, b, M: BigInt, negInvModWord: static BaseType, spareBits: static int) = ## Compute r <- a*b (mod M) in the Montgomery domain ## ## This resets r to zero before processing. Use {.noInit.} ## to avoid duplicating with Nim zero-init policy - montyMul(r.limbs, a.limbs, b.limbs, M.limbs, negInvModWord, canUseNoCarryMontyMul) + montyMul(r.limbs, a.limbs, b.limbs, M.limbs, negInvModWord, spareBits) -func montySquare*(r: var BigInt, a, M: BigInt, negInvModWord: static BaseType, canUseNoCarryMontyMul: static bool) = +func montySquare*(r: var BigInt, a, M: BigInt, negInvModWord: static BaseType, spareBits: static int) = ## Compute r <- a^2 (mod M) in the Montgomery domain ## ## This resets r to zero before processing. Use {.noInit.} ## to avoid duplicating with Nim zero-init policy - montySquare(r.limbs, a.limbs, M.limbs, negInvModWord, canUseNoCarryMontyMul) + montySquare(r.limbs, a.limbs, M.limbs, negInvModWord, spareBits) func montyPow*[mBits: static int]( a: var BigInt[mBits], exponent: openarray[byte], M, one: BigInt[mBits], negInvModWord: static BaseType, windowSize: static int, - canUseNoCarryMontyMul, canUseNoCarryMontySquare: static bool + spareBits: static int ) = ## Compute a <- a^exponent (mod M) ## ``a`` in the Montgomery domain @@ -92,12 +92,12 @@ func montyPow*[mBits: static int]( const scratchLen = if windowSize == 1: 2 else: (1 shl windowSize) + 1 var scratchSpace {.noInit.}: array[scratchLen, Limbs[mBits.wordsRequired]] - montyPow(a.limbs, exponent, M.limbs, one.limbs, negInvModWord, scratchSpace, canUseNoCarryMontyMul, canUseNoCarryMontySquare) + montyPow(a.limbs, exponent, M.limbs, one.limbs, negInvModWord, scratchSpace, spareBits) func montyPowUnsafeExponent*[mBits: static int]( a: var BigInt[mBits], exponent: openarray[byte], M, one: BigInt[mBits], negInvModWord: static BaseType, windowSize: static int, - canUseNoCarryMontyMul, canUseNoCarryMontySquare: static bool + spareBits: static int ) = ## Compute a <- a^exponent (mod M) ## ``a`` in the Montgomery domain @@ -116,7 +116,7 @@ func montyPowUnsafeExponent*[mBits: static int]( const scratchLen = if windowSize == 1: 2 else: (1 shl windowSize) + 1 var scratchSpace {.noInit.}: array[scratchLen, Limbs[mBits.wordsRequired]] - montyPowUnsafeExponent(a.limbs, exponent, M.limbs, one.limbs, negInvModWord, scratchSpace, canUseNoCarryMontyMul, canUseNoCarryMontySquare) + montyPowUnsafeExponent(a.limbs, exponent, M.limbs, one.limbs, negInvModWord, scratchSpace, spareBits) from ../io/io_bigints import exportRawUint # Workaround recursive dependencies @@ -124,7 +124,7 @@ from ../io/io_bigints import exportRawUint func montyPow*[mBits, eBits: static int]( a: var BigInt[mBits], exponent: BigInt[eBits], M, one: BigInt[mBits], negInvModWord: static BaseType, windowSize: static int, - canUseNoCarryMontyMul, canUseNoCarryMontySquare: static bool + spareBits: static int ) = ## Compute a <- a^exponent (mod M) ## 
``a`` in the Montgomery domain @@ -138,12 +138,12 @@ func montyPow*[mBits, eBits: static int]( var expBE {.noInit.}: array[(ebits + 7) div 8, byte] expBE.exportRawUint(exponent, bigEndian) - montyPow(a, expBE, M, one, negInvModWord, windowSize, canUseNoCarryMontyMul, canUseNoCarryMontySquare) + montyPow(a, expBE, M, one, negInvModWord, windowSize, spareBits) func montyPowUnsafeExponent*[mBits, eBits: static int]( a: var BigInt[mBits], exponent: BigInt[eBits], M, one: BigInt[mBits], negInvModWord: static BaseType, windowSize: static int, - canUseNoCarryMontyMul, canUseNoCarryMontySquare: static bool + spareBits: static int ) = ## Compute a <- a^exponent (mod M) ## ``a`` in the Montgomery domain @@ -161,7 +161,7 @@ func montyPowUnsafeExponent*[mBits, eBits: static int]( var expBE {.noInit.}: array[(ebits + 7) div 8, byte] expBE.exportRawUint(exponent, bigEndian) - montyPowUnsafeExponent(a, expBE, M, one, negInvModWord, windowSize, canUseNoCarryMontyMul, canUseNoCarryMontySquare) + montyPowUnsafeExponent(a, expBE, M, one, negInvModWord, windowSize, spareBits) {.pop.} # inline {.pop.} # raises no exceptions diff --git a/constantine/arithmetic/finite_fields.nim b/constantine/arithmetic/finite_fields.nim index e89cc3c45..a219051fb 100644 --- a/constantine/arithmetic/finite_fields.nim +++ b/constantine/arithmetic/finite_fields.nim @@ -56,7 +56,7 @@ func fromBig*(dst: var FF, src: BigInt) = when nimvm: dst.mres.montyResidue_precompute(src, FF.fieldMod(), FF.getR2modP(), FF.getNegInvModWord()) else: - dst.mres.montyResidue(src, FF.fieldMod(), FF.getR2modP(), FF.getNegInvModWord(), FF.canUseNoCarryMontyMul()) + dst.mres.montyResidue(src, FF.fieldMod(), FF.getR2modP(), FF.getNegInvModWord(), FF.getSpareBits()) func fromBig*[C: static Curve](T: type FF[C], src: BigInt): FF[C] {.noInit.} = ## Convert a BigInt to its Montgomery form @@ -65,7 +65,7 @@ func fromBig*[C: static Curve](T: type FF[C], src: BigInt): FF[C] {.noInit.} = func toBig*(src: FF): auto {.noInit, inline.} = ## Convert a finite-field element to a BigInt in natural representation var r {.noInit.}: typeof(src.mres) - r.redc(src.mres, FF.fieldMod(), FF.getNegInvModWord(), FF.canUseNoCarryMontyMul()) + r.redc(src.mres, FF.fieldMod(), FF.getNegInvModWord(), FF.getSpareBits()) return r # Copy @@ -201,11 +201,11 @@ func double*(r: var FF, a: FF) {.meter.} = func prod*(r: var FF, a, b: FF) {.meter.} = ## Store the product of ``a`` by ``b`` modulo p into ``r`` ## ``r`` is initialized / overwritten - r.mres.montyMul(a.mres, b.mres, FF.fieldMod(), FF.getNegInvModWord(), FF.canUseNoCarryMontyMul()) + r.mres.montyMul(a.mres, b.mres, FF.fieldMod(), FF.getNegInvModWord(), FF.getSpareBits()) func square*(r: var FF, a: FF) {.meter.} = ## Squaring modulo p - r.mres.montySquare(a.mres, FF.fieldMod(), FF.getNegInvModWord(), FF.canUseNoCarryMontySquare()) + r.mres.montySquare(a.mres, FF.fieldMod(), FF.getNegInvModWord(), FF.getSpareBits()) func neg*(r: var FF, a: FF) {.meter.} = ## Negate modulo p @@ -279,8 +279,7 @@ func pow*(a: var FF, exponent: BigInt) = exponent, FF.fieldMod(), FF.getMontyOne(), FF.getNegInvModWord(), windowSize, - FF.canUseNoCarryMontyMul(), - FF.canUseNoCarryMontySquare() + FF.getSpareBits() ) func pow*(a: var FF, exponent: openarray[byte]) = @@ -292,8 +291,7 @@ func pow*(a: var FF, exponent: openarray[byte]) = exponent, FF.fieldMod(), FF.getMontyOne(), FF.getNegInvModWord(), windowSize, - FF.canUseNoCarryMontyMul(), - FF.canUseNoCarryMontySquare() + FF.getSpareBits() ) func powUnsafeExponent*(a: var FF, exponent: BigInt) = @@ -312,8 
+310,7 @@ func powUnsafeExponent*(a: var FF, exponent: BigInt) = exponent, FF.fieldMod(), FF.getMontyOne(), FF.getNegInvModWord(), windowSize, - FF.canUseNoCarryMontyMul(), - FF.canUseNoCarryMontySquare() + FF.getSpareBits() ) func powUnsafeExponent*(a: var FF, exponent: openarray[byte]) = @@ -332,8 +329,7 @@ func powUnsafeExponent*(a: var FF, exponent: openarray[byte]) = exponent, FF.fieldMod(), FF.getMontyOne(), FF.getNegInvModWord(), windowSize, - FF.canUseNoCarryMontyMul(), - FF.canUseNoCarryMontySquare() + FF.getSpareBits() ) # ############################################################ @@ -350,7 +346,7 @@ func `*=`*(a: var FF, b: FF) {.meter.} = func square*(a: var FF) {.meter.} = ## Squaring modulo p - a.mres.montySquare(a.mres, FF.fieldMod(), FF.getNegInvModWord(), FF.canUseNoCarryMontySquare()) + a.mres.montySquare(a.mres, FF.fieldMod(), FF.getNegInvModWord(), FF.getSpareBits()) func square_repeated*(r: var FF, num: int) {.meter.} = ## Repeated squarings diff --git a/constantine/arithmetic/finite_fields_double_precision.nim b/constantine/arithmetic/finite_fields_double_precision.nim index 620596f7e..9b1d1c432 100644 --- a/constantine/arithmetic/finite_fields_double_precision.nim +++ b/constantine/arithmetic/finite_fields_double_precision.nim @@ -75,7 +75,7 @@ func redc2x*(r: var Fp, a: FpDbl) = a.limbs2x, Fp.C.Mod.limbs, Fp.getNegInvModWord(), - Fp.canUseNoCarryMontyMul() + Fp.getSpareBits() ) func diff2xUnr*(r: var FpDbl, a, b: FpDbl) = diff --git a/constantine/arithmetic/limbs_montgomery.nim b/constantine/arithmetic/limbs_montgomery.nim index eac18b281..a455c5141 100644 --- a/constantine/arithmetic/limbs_montgomery.nim +++ b/constantine/arithmetic/limbs_montgomery.nim @@ -392,7 +392,7 @@ func montyRedc2x_Comba[N: static int]( func montyMul*( r: var Limbs, a, b, M: Limbs, - m0ninv: static BaseType, canUseNoCarryMontyMul: static bool) {.inline.} = + m0ninv: static BaseType, spareBits: static int) {.inline.} = ## Compute r <- a*b (mod M) in the Montgomery domain ## `m0ninv` = -1/M (mod SecretWord). Our words are 2^32 or 2^64 ## @@ -419,7 +419,7 @@ func montyMul*( # The implementation is visible from here, the compiler can make decision whether to: # - specialize/duplicate code for m0ninv == 1 (especially if only 1 curve is needed) # - keep it generic and optimize code size - when canUseNoCarryMontyMul: + when spareBits >= 1: when UseASM_X86_64 and a.len in {2 .. 6}: # TODO: handle spilling if ({.noSideEffect.}: hasBmi2()) and ({.noSideEffect.}: hasAdx()): montMul_CIOS_nocarry_asm_adx_bmi2(r, a, b, M, m0ninv) @@ -431,14 +431,14 @@ func montyMul*( montyMul_FIPS(r, a, b, M, m0ninv) func montySquare*(r: var Limbs, a, M: Limbs, - m0ninv: static BaseType, canUseNoCarryMontySquare: static bool) {.inline.} = + m0ninv: static BaseType, spareBits: static int) {.inline.} = ## Compute r <- a^2 (mod M) in the Montgomery domain ## `m0ninv` = -1/M (mod SecretWord). 
Our words are 2^31 or 2^63 # TODO: needs optimization similar to multiplication - montyMul(r, a, a, M, m0ninv, canUseNoCarryMontySquare) + montyMul(r, a, a, M, m0ninv, spareBits) - # when canUseNoCarryMontySquare: + # when spareBits >= 2: # # TODO: Deactivated # # Off-by one on 32-bit on the least significant bit # # for Fp[BLS12-381] with inputs @@ -463,22 +463,22 @@ func montyRedc2x*[N: static int]( r: var array[N, SecretWord], a: array[N*2, SecretWord], M: array[N, SecretWord], - m0ninv: BaseType, canUseNoCarryMontyMul: static bool) {.inline.} = + m0ninv: BaseType, spareBits: static int) {.inline.} = ## Montgomery reduce a double-precision bigint modulo M when UseASM_X86_64 and r.len <= 6: if ({.noSideEffect.}: hasBmi2()) and ({.noSideEffect.}: hasAdx()): - montRed_asm_adx_bmi2(r, a, M, m0ninv, canUseNoCarryMontyMul) + montRed_asm_adx_bmi2(r, a, M, m0ninv, spareBits) else: - montRed_asm(r, a, M, m0ninv, canUseNoCarryMontyMul) + montRed_asm(r, a, M, m0ninv, spareBits) elif UseASM_X86_32 and r.len <= 6: # TODO: Assembly faster than GCC but slower than Clang - montRed_asm(r, a, M, m0ninv, canUseNoCarryMontyMul) + montRed_asm(r, a, M, m0ninv, spareBits) else: montyRedc2x_CIOS(r, a, M, m0ninv) # montyRedc2x_Comba(r, a, M, m0ninv) func redc*(r: var Limbs, a, one, M: Limbs, - m0ninv: static BaseType, canUseNoCarryMontyMul: static bool) = + m0ninv: static BaseType, spareBits: static int) = ## Transform a bigint ``a`` from it's Montgomery N-residue representation (mod N) ## to the regular natural representation (mod N) ## @@ -497,10 +497,10 @@ func redc*(r: var Limbs, a, one, M: Limbs, # - http://langevin.univ-tln.fr/cours/MLC/extra/montgomery.pdf # Montgomery original paper # - montyMul(r, a, one, M, m0ninv, canUseNoCarryMontyMul) + montyMul(r, a, one, M, m0ninv, spareBits) func montyResidue*(r: var Limbs, a, M, r2modM: Limbs, - m0ninv: static BaseType, canUseNoCarryMontyMul: static bool) = + m0ninv: static BaseType, spareBits: static int) = ## Transform a bigint ``a`` from it's natural representation (mod N) ## to a the Montgomery n-residue representation ## @@ -518,7 +518,7 @@ func montyResidue*(r: var Limbs, a, M, r2modM: Limbs, ## Important: `r` is overwritten ## The result `r` buffer size MUST be at least the size of `M` buffer # Reference: https://eprint.iacr.org/2017/1057.pdf - montyMul(r, a, r2ModM, M, m0ninv, canUseNoCarryMontyMul) + montyMul(r, a, r2ModM, M, m0ninv, spareBits) # Montgomery Modular Exponentiation # ------------------------------------------ @@ -565,7 +565,7 @@ func montyPowPrologue( a: var Limbs, M, one: Limbs, m0ninv: static BaseType, scratchspace: var openarray[Limbs], - canUseNoCarryMontyMul: static bool + spareBits: static int ): uint = ## Setup the scratchspace ## Returns the fixed-window size for exponentiation with window optimization. 
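  # A sketch of the scratchspace layout assumed below: scratchspace[0] is a
  # temporary for squarings, scratchspace[1] receives the constant-time
  # window lookup, and the prologue loop fills scratchspace[k] = a^(k-1)
  # for k in 2 .. 2^window, so the top entry holds a^(2^window - 1).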
@@ -579,7 +579,7 @@ func montyPowPrologue( else: scratchspace[2] = a for k in 2 ..< 1 shl result: - scratchspace[k+1].montyMul(scratchspace[k], a, M, m0ninv, canUseNoCarryMontyMul) + scratchspace[k+1].montyMul(scratchspace[k], a, M, m0ninv, spareBits) # Set a to one a = one @@ -593,7 +593,7 @@ func montyPowSquarings( window: uint, acc, acc_len: var uint, e: var int, - canUseNoCarryMontySquare: static bool + spareBits: static int ): tuple[k, bits: uint] {.inline.}= ## Squaring step of exponentiation by squaring ## Get the next k bits in range [1, window) @@ -629,7 +629,7 @@ func montyPowSquarings( # We have k bits and can do k squaring for i in 0 ..< k: - tmp.montySquare(a, M, m0ninv, canUseNoCarryMontySquare) + tmp.montySquare(a, M, m0ninv, spareBits) a = tmp return (k, bits) @@ -640,8 +640,7 @@ func montyPow*( M, one: Limbs, m0ninv: static BaseType, scratchspace: var openarray[Limbs], - canUseNoCarryMontyMul: static bool, - canUseNoCarryMontySquare: static bool + spareBits: static int ) = ## Modular exponentiation r = a^exponent mod M ## in the Montgomery domain @@ -669,7 +668,7 @@ func montyPow*( ## A window of size 5 requires (2^5 + 1)*(381 + 7)/8 = 33 * 48 bytes = 1584 bytes ## of scratchspace (on the stack). - let window = montyPowPrologue(a, M, one, m0ninv, scratchspace, canUseNoCarryMontyMul) + let window = montyPowPrologue(a, M, one, m0ninv, scratchspace, spareBits) # We process bits with from most to least significant. # At each loop iteration with have acc_len bits in acc. @@ -684,7 +683,7 @@ func montyPow*( a, exponent, M, m0ninv, scratchspace[0], window, acc, acc_len, e, - canUseNoCarryMontySquare + spareBits ) # Window lookup: we set scratchspace[1] to the lookup value. @@ -699,7 +698,7 @@ func montyPow*( # Multiply with the looked-up value # we keep the product only if the exponent bits are not all zeroes - scratchspace[0].montyMul(a, scratchspace[1], M, m0ninv, canUseNoCarryMontyMul) + scratchspace[0].montyMul(a, scratchspace[1], M, m0ninv, spareBits) a.ccopy(scratchspace[0], SecretWord(bits).isNonZero()) func montyPowUnsafeExponent*( @@ -708,8 +707,7 @@ func montyPowUnsafeExponent*( M, one: Limbs, m0ninv: static BaseType, scratchspace: var openarray[Limbs], - canUseNoCarryMontyMul: static bool, - canUseNoCarryMontySquare: static bool + spareBits: static int ) = ## Modular exponentiation r = a^exponent mod M ## in the Montgomery domain @@ -723,7 +721,7 @@ func montyPowUnsafeExponent*( # TODO: scratchspace[1] is unused when window > 1 - let window = montyPowPrologue(a, M, one, m0ninv, scratchspace, canUseNoCarryMontyMul) + let window = montyPowPrologue(a, M, one, m0ninv, scratchspace, spareBits) var acc, acc_len: uint @@ -733,16 +731,16 @@ func montyPowUnsafeExponent*( a, exponent, M, m0ninv, scratchspace[0], window, acc, acc_len, e, - canUseNoCarryMontySquare + spareBits ) ## Warning ⚠️: Exposes the exponent bits if bits != 0: if window > 1: - scratchspace[0].montyMul(a, scratchspace[1+bits], M, m0ninv, canUseNoCarryMontyMul) + scratchspace[0].montyMul(a, scratchspace[1+bits], M, m0ninv, spareBits) else: # scratchspace[1] holds the original `a` - scratchspace[0].montyMul(a, scratchspace[1], M, m0ninv, canUseNoCarryMontyMul) + scratchspace[0].montyMul(a, scratchspace[1], M, m0ninv, spareBits) a = scratchspace[0] {.pop.} # raises no exceptions diff --git a/constantine/config/curves_derived.nim b/constantine/config/curves_derived.nim index 87530a76e..21ffa39d3 100644 --- a/constantine/config/curves_derived.nim +++ b/constantine/config/curves_derived.nim @@ -51,18 +51,10 @@ 
macro genDerivedConstants*(mode: static DerivedConstantMode): untyped = let M = if mode == kModulus: bindSym(curve & "_Modulus") else: bindSym(curve & "_Order") - # const MyCurve_CanUseNoCarryMontyMul = useNoCarryMontyMul(MyCurve_Modulus) + # const MyCurve_SpareBits = countSpareBits(MyCurve_Modulus) result.add newConstStmt( - used(curve & ff & "_CanUseNoCarryMontyMul"), newCall( - bindSym"useNoCarryMontyMul", - M - ) - ) - - # const MyCurve_CanUseNoCarryMontySquare = useNoCarryMontySquare(MyCurve_Modulus) - result.add newConstStmt( - used(curve & ff & "_CanUseNoCarryMontySquare"), newCall( - bindSym"useNoCarryMontySquare", + used(curve & ff & "_SpareBits"), newCall( + bindSym"countSpareBits", M ) ) diff --git a/constantine/config/curves_prop_derived.nim b/constantine/config/curves_prop_derived.nim index 2439f4ebf..1bac65f77 100644 --- a/constantine/config/curves_prop_derived.nim +++ b/constantine/config/curves_prop_derived.nim @@ -61,15 +61,18 @@ template fieldMod*(Field: type FF): auto = else: Field.C.getCurveOrder() -macro canUseNoCarryMontyMul*(ff: type FF): untyped = - ## Returns true if the Modulus is compatible with a fast - ## Montgomery multiplication that avoids many carries - result = bindConstant(ff, "CanUseNoCarryMontyMul") - -macro canUseNoCarryMontySquare*(ff: type FF): untyped = - ## Returns true if the Modulus is compatible with a fast - ## Montgomery squaring that avoids many carries - result = bindConstant(ff, "CanUseNoCarryMontySquare") +macro getSpareBits*(ff: type FF): untyped = + ## Returns the number of extra bits + ## in the modulus M representation. + ## + ## This is used for no-carry operations + ## or lazily reduced operations by allowing + ## output in range: + ## - [0, 2p) if 1 bit is available + ## - [0, 4p) if 2 bits are available + ## - [0, 8p) if 3 bits are available + ## - ... + result = bindConstant(ff, "SpareBits") macro getR2modP*(ff: type FF): untyped = ## Get the Montgomery "R^2 mod P" constant associated to a curve field modulus diff --git a/constantine/config/precompute.nim b/constantine/config/precompute.nim index a5cf9f5d2..7a478bcfd 100644 --- a/constantine/config/precompute.nim +++ b/constantine/config/precompute.nim @@ -238,21 +238,20 @@ func checkValidModulus(M: BigInt) = doAssert msb == expectedMsb, "Internal Error: the modulus must use all declared bits and only those" -func useNoCarryMontyMul*(M: BigInt): bool = - ## Returns if the modulus is compatible - ## with the no-carry Montgomery Multiplication - ## from https://hackmd.io/@zkteam/modular_multiplication - # Indirection needed because static object are buggy - # https://github.com/nim-lang/Nim/issues/9679 - BaseType(M.limbs[^1]) < high(BaseType) shr 1 - -func useNoCarryMontySquare*(M: BigInt): bool = - ## Returns if the modulus is compatible - ## with the no-carry Montgomery Squaring - ## from https://hackmd.io/@zkteam/modular_multiplication - # Indirection needed because static object are buggy - # https://github.com/nim-lang/Nim/issues/9679 - BaseType(M.limbs[^1]) < high(BaseType) shr 2 +func countSpareBits*(M: BigInt): int = + ## Count the number of extra bits + ## in the modulus M representation. + ## + ## This is used for no-carry operations + ## or lazily reduced operations by allowing + ## output in range: + ## - [0, 2p) if 1 bit is available + ## - [0, 4p) if 2 bits are available + ## - [0, 8p) if 3 bits are available + ## - ... 
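  # Worked example, assuming 64-bit words: the BLS12-381 modulus is 381 bits
  # wide and is stored in 6 * 64 = 384 bits, so the most significant set bit
  # of its top word is bit 60 and this returns 64 - 1 - 60 = 3 spare bits,
  # i.e. lazily reduced values up to 8p still fit in the representation.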
+ checkValidModulus(M) + let msb = log2(BaseType(M.limbs[^1])) + result = WordBitWidth - 1 - msb.int func invModBitwidth[T: SomeUnsignedInt](a: T): T = # We use BaseType for return value because static distinct type diff --git a/constantine/tower_field_extensions/extension_fields.nim b/constantine/tower_field_extensions/extension_fields.nim index b02a1d76d..65b07a380 100644 --- a/constantine/tower_field_extensions/extension_fields.nim +++ b/constantine/tower_field_extensions/extension_fields.nim @@ -333,11 +333,11 @@ template doublePrec(T: type ExtensionField): type = func has1extraBit(E: type ExtensionField): bool = ## We construct extensions only on Fp (and not Fr) - canUseNoCarryMontyMul(Fp[E.C]) + getSpareBits(Fp[E.C]) >= 1 func has2extraBits(E: type ExtensionField): bool = ## We construct extensions only on Fp (and not Fr) - canUseNoCarryMontySquare(Fp[E.C]) + getSpareBits(Fp[E.C]) >= 2 template C(E: type ExtensionField2x): Curve = E.F.C diff --git a/helpers/prng_unsafe.nim b/helpers/prng_unsafe.nim index 38740ac48..80b4a99a6 100644 --- a/helpers/prng_unsafe.nim +++ b/helpers/prng_unsafe.nim @@ -146,7 +146,7 @@ func random_unsafe(rng: var RngState, a: var FF) = # Note: a simple modulo will be biaised but it's simple and "fast" reduced.reduce(unreduced, FF.fieldMod()) - a.mres.montyResidue(reduced, FF.fieldMod(), FF.getR2modP(), FF.getNegInvModWord(), FF.canUseNoCarryMontyMul()) + a.mres.montyResidue(reduced, FF.fieldMod(), FF.getR2modP(), FF.getNegInvModWord(), FF.getSpareBits()) func random_unsafe(rng: var RngState, a: var ExtensionField) = ## Recursively initialize an extension Field element @@ -177,7 +177,7 @@ func random_highHammingWeight(rng: var RngState, a: var FF) = # Note: a simple modulo will be biaised but it's simple and "fast" reduced.reduce(unreduced, FF.fieldMod()) - a.mres.montyResidue(reduced, FF.fieldMod(), FF.getR2modP(), FF.getNegInvModWord(), FF.canUseNoCarryMontyMul()) + a.mres.montyResidue(reduced, FF.fieldMod(), FF.getR2modP(), FF.getNegInvModWord(), FF.getSpareBits()) func random_highHammingWeight(rng: var RngState, a: var ExtensionField) = ## Recursively initialize an extension Field element @@ -222,7 +222,7 @@ func random_long01Seq(rng: var RngState, a: var FF) = # Note: a simple modulo will be biaised but it's simple and "fast" reduced.reduce(unreduced, FF.fieldMod()) - a.mres.montyResidue(reduced, FF.fieldMod(), FF.getR2modP(), FF.getNegInvModWord(), FF.canUseNoCarryMontyMul()) + a.mres.montyResidue(reduced, FF.fieldMod(), FF.getR2modP(), FF.getNegInvModWord(), FF.getSpareBits()) func random_long01Seq(rng: var RngState, a: var ExtensionField) = ## Recursively initialize an extension Field element diff --git a/tests/t_finite_fields_mulsquare.nim b/tests/t_finite_fields_mulsquare.nim index 49cd86bda..722cef424 100644 --- a/tests/t_finite_fields_mulsquare.nim +++ b/tests/t_finite_fields_mulsquare.nim @@ -27,7 +27,7 @@ echo "test_finite_fields_mulsquare xoshiro512** seed: ", seed static: doAssert defined(testingCurves), "This modules requires the -d:testingCurves compile option" proc sanity(C: static Curve) = - test "Squaring 0,1,2 with "& $Curve(C) & " [FastSquaring = " & $Fp[C].canUseNoCarryMontySquare & "]": + test "Squaring 0,1,2 with "& $Curve(C) & " [FastSquaring = " & $(Fp[C].getSpareBits() >= 2) & "]": block: # 0² mod var n: Fp[C] @@ -89,7 +89,7 @@ mainSanity() proc mainSelectCases() = suite "Modular Squaring: selected tricky cases" & " [" & $WordBitwidth & "-bit mode]": - test "P-256 [FastSquaring = " & $Fp[P256].canUseNoCarryMontySquare & "]": + test 
"P-256 [FastSquaring = " & $(Fp[P256].getSpareBits() >= 2) & "]": block: # Triggered an issue in the (t[N+1], t[N]) = t[N] + (A1, A0) # between the squaring and reduction step, with t[N+1] and A1 being carry bits. @@ -136,7 +136,7 @@ proc random_long01Seq(C: static Curve) = doAssert bool(r_mul == r_sqr) suite "Random Modular Squaring is consistent with Modular Multiplication" & " [" & $WordBitwidth & "-bit mode]": - test "Random squaring mod P-224 [FastSquaring = " & $Fp[P224].canUseNoCarryMontySquare & "]": + test "Random squaring mod P-224 [FastSquaring = " & $(Fp[P224].getSpareBits() >= 2) & "]": for _ in 0 ..< Iters: randomCurve(P224) for _ in 0 ..< Iters: @@ -144,7 +144,8 @@ suite "Random Modular Squaring is consistent with Modular Multiplication" & " [" for _ in 0 ..< Iters: random_long01Seq(P224) - test "Random squaring mod P-256 [FastSquaring = " & $Fp[P256].canUseNoCarryMontySquare & "]": + test "Random squaring mod P-256 [FastSquaring = " & $(Fp[P256].getSpareBits() >= 2) & "]": + echo "Fp[P256].getSpareBits(): ", Fp[P256].getSpareBits() for _ in 0 ..< Iters: randomCurve(P256) for _ in 0 ..< Iters: @@ -152,7 +153,7 @@ suite "Random Modular Squaring is consistent with Modular Multiplication" & " [" for _ in 0 ..< Iters: random_long01Seq(P256) - test "Random squaring mod BLS12_381 [FastSquaring = " & $Fp[BLS12_381].canUseNoCarryMontySquare & "]": + test "Random squaring mod BLS12_381 [FastSquaring = " & $(Fp[BLS12_381].getSpareBits() >= 2) & "]": for _ in 0 ..< Iters: randomCurve(BLS12_381) for _ in 0 ..< Iters: diff --git a/tests/t_fr.nim b/tests/t_fr.nim index d6eb349a4..de1e47a4e 100644 --- a/tests/t_fr.nim +++ b/tests/t_fr.nim @@ -25,7 +25,7 @@ echo "\n------------------------------------------------------\n" echo "test_fr xoshiro512** seed: ", seed proc sanity(C: static Curve) = - test "Fr: Squaring 0,1,2 with "& $Fr[C] & " [FastSquaring = " & $Fr[C].canUseNoCarryMontySquare & "]": + test "Fr: Squaring 0,1,2 with "& $Fr[C] & " [FastSquaring = " & $(Fr[C].getSpareBits() >= 2) & "]": block: # 0² mod var n: Fr[C] @@ -112,7 +112,7 @@ proc random_long01Seq(C: static Curve) = doAssert bool(r_mul == r_sqr) suite "Fr: Random Modular Squaring is consistent with Modular Multiplication" & " [" & $WordBitwidth & "-bit mode]": - test "Random squaring mod r_BN254_Snarks [FastSquaring = " & $Fr[BN254_Snarks].canUseNoCarryMontySquare & "]": + test "Random squaring mod r_BN254_Snarks [FastSquaring = " & $(Fr[BN254_Snarks].getSpareBits() >= 2) & "]": for _ in 0 ..< Iters: randomCurve(BN254_Snarks) for _ in 0 ..< Iters: @@ -120,7 +120,7 @@ suite "Fr: Random Modular Squaring is consistent with Modular Multiplication" & for _ in 0 ..< Iters: random_long01Seq(BN254_Snarks) - test "Random squaring mod r_BLS12_381 [FastSquaring = " & $Fr[BLS12_381].canUseNoCarryMontySquare & "]": + test "Random squaring mod r_BLS12_381 [FastSquaring = " & $(Fr[BLS12_381].getSpareBits() >= 2) & "]": for _ in 0 ..< Iters: randomCurve(BLS12_381) for _ in 0 ..< Iters: From a26587277a619e6183bb0e215addc13fe758af70 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Mamy=20Andr=C3=A9-Ratsimbazafy?= Date: Mon, 8 Feb 2021 23:03:55 +0100 Subject: [PATCH 15/22] Complete the double-precision implementation --- .../finite_fields_double_precision.nim | 82 +++++++++++--- tests/t_finite_fields_double_precision.nim | 100 +++++++++++++++++- 2 files changed, 164 insertions(+), 18 deletions(-) diff --git a/constantine/arithmetic/finite_fields_double_precision.nim b/constantine/arithmetic/finite_fields_double_precision.nim index 
9b1d1c432..186e9e781 100644
--- a/constantine/arithmetic/finite_fields_double_precision.nim
+++ b/constantine/arithmetic/finite_fields_double_precision.nim
@@ -41,6 +41,9 @@ template doublePrec*(T: type Fp): type =
 func `==`*(a, b: FpDbl): SecretBool =
   a.limbs2x == b.limbs2x

+func isZero*(a: FpDbl): SecretBool =
+  a.limbs2x.isZero()
+
 func setZero*(a: var FpDbl) =
   a.limbs2x.setZero()

@@ -92,13 +95,15 @@ func diff2xMod*(r: var FpDbl, a, b: FpDbl) =
   when UseASM_X86_64:
     submod2x_asm(r.limbs2x, a.limbs2x, b.limbs2x, FpDbl.C.Mod.limbs)
   else:
+    # Substraction step
     var underflowed = SecretBool r.limbs2x.diff(a.limbs2x, b.limbs2x)

+    # Conditional reduction by 2ⁿp
     const N = r.limbs2x.len div 2
     const M = FpDbl.C.Mod
     var carry = Carry(0)
     var sum: SecretWord
-    for i in 0 ..< N:
+    staticFor i, 0, N:
       addC(carry, sum, r.limbs2x[i+N], M.limbs[i], carry)
       underflowed.ccopy(r.limbs2x[i+N], sum)

@@ -110,32 +115,63 @@ func sum2xUnr*(r: var FpDbl, a, b: FpDbl) =
   discard r.limbs2x.sum(a.limbs2x, b.limbs2x)

 func sum2xMod*(r: var FpDbl, a, b: FpDbl) =
-  ## Double-precision modular substraction
+  ## Double-precision modular addition
   ## Output is conditionally reduced by 2ⁿp
   ## to stay in the [0, 2ⁿp) range
-  when UseASM_X86_64:
+  when false:
     addmod2x_asm(r.limbs2x, a.limbs2x, b.limbs2x, FpDbl.C.Mod.limbs)
   else:
+    # Addition step
     var overflowed = SecretBool r.limbs2x.sum(a.limbs2x, b.limbs2x)

     const N = r.limbs2x.len div 2
     const M = FpDbl.C.Mod
+    # Test >= 2ⁿp
     var borrow = Borrow(0)
-    var diff: SecretWord
-    for i in 0 ..< N:
-      subB(borrow, diff, r.limbs2x[i+N], M.limbs[i], borrow)
-      overflowed.ccopy(r.limbs2x[i+N], diff)
+    var t{.noInit.}: Limbs[N]
+    staticFor i, 0, N:
+      subB(borrow, t[i], r.limbs2x[i+N], M.limbs[i], borrow)
+
+    # If no borrow occurred, r was at least 2ⁿp
+    overflowed = overflowed or not(SecretBool borrow)

-func prod2x*(
+    # Conditional reduction by 2ⁿp
+    staticFor i, 0, N:
+      SecretBool(overflowed).ccopy(r.limbs2x[i+N], t[i])
+
+func neg2xMod*(r: var FpDbl, a: FpDbl) =
+  ## Double-precision modular negation
+  ## Negate modulo 2ⁿp
+  when false: # TODO UseASM_X86_64:
+    {.error: "Implement assembly".}
+  else:
+    # If a = 0 we need r = 0 and not r = M
+    # as comparison operators assume unicity
+    # of the modular representation.
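    # Concretely the negation below computes 2ⁿp - a: the low N words of 2ⁿp
    # are zero and its high N words are p, which is why the borrow chain
    # subtracts a's low half from zero and its high half from M.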
+ # Also make sure to handle aliasing where r.addr = a.addr + var t {.noInit.}: FpDbl + let isZero = a.isZero() + const N = r.limbs2x.len div 2 + const M = FpDbl.C.Mod + var borrow = Borrow(0) + # 2ⁿp is filled with 0 in the first half + staticFor i, 0, N: + subB(borrow, t.limbs2x[i], Zero, a.limbs2x[i], borrow) + # 2ⁿp has p (shifted) for the rest of the limbs + staticFor i, N, r.limbs2x.len: + subB(borrow, t.limbs2x[i], M.limbs[i-N], a.limbs2x[i], borrow) + + # Zero the result if input was zero + t.limbs2x.czero(isZero) + r = t + +func prod2xImpl( r {.noAlias.}: var FpDbl, a {.noAlias.}: FpDbl, b: static int) = ## Multiplication by a small integer known at compile-time - ## Requires no aliasing - const negate = b < 0 - const b = if negate: -b - else: b - when negate: - r.diff2xMod(typeof(a)(), a) + ## Requires no aliasing and b positive + static: doAssert b >= 0 + when b == 0: r.setZero() elif b == 1: @@ -172,18 +208,18 @@ func prod2x*( r.sum2xMod(r, r) # 8 r.sum2xMod(r, a) elif b == 10: - a.sum2xMod(a, a) + r.sum2xMod(a, a) r.sum2xMod(r, r) r.sum2xMod(r, a) # 5 r.sum2xMod(r, r) elif b == 11: - a.sum2xMod(a, a) + r.sum2xMod(a, a) r.sum2xMod(r, r) r.sum2xMod(r, a) # 5 r.sum2xMod(r, r) r.sum2xMod(r, a) elif b == 12: - a.sum2xMod(a, a) + r.sum2xMod(a, a) r.sum2xMod(r, r) # 4 let t4 = a r.sum2xMod(r, r) # 8 @@ -191,5 +227,17 @@ func prod2x*( else: {.error: "Multiplication by this small int not implemented".} +func prod2x*(r: var FpDbl, a: FpDbl, b: static int) = + ## Multiplication by a small integer known at compile-time + const negate = b < 0 + const b = if negate: -b + else: b + when negate: + var t {.noInit.}: typeof(r) + t.neg2xMod(a) + else: + let t = a + prod2xImpl(r, t, b) + {.pop.} # inline {.pop.} # raises no exceptions diff --git a/tests/t_finite_fields_double_precision.nim b/tests/t_finite_fields_double_precision.nim index 942e6d421..e52f25156 100644 --- a/tests/t_finite_fields_double_precision.nim +++ b/tests/t_finite_fields_double_precision.nim @@ -24,10 +24,72 @@ rng.seed(seed) echo "\n------------------------------------------------------\n" echo "test_finite_fields_double_precision xoshiro512** seed: ", seed +template addsubnegTest(rng_gen: untyped): untyped = + proc `addsubneg _ rng_gen`(C: static Curve) = + # Try to exercise all code paths for in-place/out-of-place add/sum/sub/diff/double/neg + # (1 - (-a) - b + (-a) - 2a) + (2a + 2b + (-b)) == 1 + let aFp = rng_gen(rng, Fp[C]) + let bFp = rng_gen(rng, Fp[C]) + var accumFp {.noInit.}: Fp[C] + var OneFp {.noInit.}: Fp[C] + var accum {.noInit.}, One {.noInit.}, a{.noInit.}, na{.noInit.}, b{.noInit.}, nb{.noInit.}, a2 {.noInit.}, b2 {.noInit.}: FpDbl[C] + + OneFp.setOne() + One.prod2x(OneFp, OneFp) + a.prod2x(aFp, OneFp) + b.prod2x(bFp, OneFp) + + block: # sanity check + var t: Fp[C] + t.redc2x(One) + doAssert bool t.isOne() + + a2.sum2xMod(a, a) + na.neg2xMod(a) + + block: # sanity check + var t0, t1: Fp[C] + t0.redc2x(na) + t1.neg(aFp) + doAssert bool(t0 == t1), + "Beware, if the hex are the same, it means the outputs are the same (mod p),\n" & + "but one might not be completely reduced\n" & + " t0: " & t0.toHex() & "\n" & + " t1: " & t1.toHex() & "\n" + + block: # sanity check + var t0, t1: Fp[C] + t0.redc2x(a2) + t1.double(aFp) + doAssert bool(t0 == t1), + "Beware, if the hex are the same, it means the outputs are the same (mod p),\n" & + "but one might not be completely reduced\n" & + " t0: " & t0.toHex() & "\n" & + " t1: " & t1.toHex() & "\n" + + b2.sum2xMod(b, b) + nb.neg2xMod(b) + + accum.diff2xMod(One, na) + 
accum.diff2xMod(accum, b) + accum.sum2xMod(accum, na) + accum.diff2xMod(accum, a2) + + var t{.noInit.}: FpDbl[C] + t.sum2xMod(a2, b2) + t.sum2xMod(t, nb) + + accum.sum2xMod(accum, t) + accumFp.redc2x(accum) + doAssert bool accumFp.isOne(), + "Beware, if the hex are the same, it means the outputs are the same (mod p),\n" & + "but one might not be completely reduced\n" & + " accumFp: " & accumFp.toHex() + template mulTest(rng_gen: untyped): untyped = proc `mul _ rng_gen`(C: static Curve) = let a = rng_gen(rng, Fp[C]) - let b = rng.random_unsafe(Fp[C]) + let b = rng_gen(rng, Fp[C]) var r_fp{.noInit.}, r_fpDbl{.noInit.}: Fp[C] var tmpDbl{.noInit.}: FpDbl[C] @@ -49,6 +111,9 @@ template sqrTest(rng_gen: untyped): untyped = doAssert bool(mulDbl == sqrDbl) +addsubnegTest(random_unsafe) +addsubnegTest(randomHighHammingWeight) +addsubnegTest(random_long01Seq) mulTest(random_unsafe) mulTest(randomHighHammingWeight) mulTest(random_long01Seq) @@ -56,6 +121,39 @@ sqrTest(random_unsafe) sqrTest(randomHighHammingWeight) sqrTest(random_long01Seq) +suite "Field Addition/Substraction/Negation via double-precision field elements" & " [" & $WordBitwidth & "-bit mode]": + test "With P-224 field modulus": + for _ in 0 ..< Iters: + addsubneg_random_unsafe(P224) + for _ in 0 ..< Iters: + addsubneg_randomHighHammingWeight(P224) + for _ in 0 ..< Iters: + addsubneg_random_long01Seq(P224) + + test "With P-256 field modulus": + for _ in 0 ..< Iters: + addsubneg_random_unsafe(P256) + for _ in 0 ..< Iters: + addsubneg_randomHighHammingWeight(P256) + for _ in 0 ..< Iters: + addsubneg_random_long01Seq(P256) + + test "With BN254_Snarks field modulus": + for _ in 0 ..< Iters: + addsubneg_random_unsafe(BN254_Snarks) + for _ in 0 ..< Iters: + addsubneg_randomHighHammingWeight(BN254_Snarks) + for _ in 0 ..< Iters: + addsubneg_random_long01Seq(BN254_Snarks) + + test "With BLS12_381 field modulus": + for _ in 0 ..< Iters: + addsubneg_random_unsafe(BLS12_381) + for _ in 0 ..< Iters: + addsubneg_randomHighHammingWeight(BLS12_381) + for _ in 0 ..< Iters: + addsubneg_random_long01Seq(BLS12_381) + suite "Field Multiplication via double-precision field elements is consistent with single-width." 
& " [" & $WordBitwidth & "-bit mode]": test "With P-224 field modulus": for _ in 0 ..< Iters: From f5eb4e21b66695cf47a4808759798199330bad29 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Mamy=20Andr=C3=A9-Ratsimbazafy?= Date: Mon, 8 Feb 2021 23:23:16 +0100 Subject: [PATCH 16/22] Use double-precision path for Fp4 squaring and mul --- .../extension_fields.nim | 270 ++++++------------ 1 file changed, 83 insertions(+), 187 deletions(-) diff --git a/constantine/tower_field_extensions/extension_fields.nim b/constantine/tower_field_extensions/extension_fields.nim index 65b07a380..947231533 100644 --- a/constantine/tower_field_extensions/extension_fields.nim +++ b/constantine/tower_field_extensions/extension_fields.nim @@ -222,82 +222,12 @@ func csub*(a: var ExtensionField, b: ExtensionField, ctl: SecretBool) = func `*=`*(a: var ExtensionField, b: static int) = ## Multiplication by a small integer known at compile-time - - const negate = b < 0 - const b = if negate: -b - else: b - when negate: - a.neg(a) - when b == 0: - a.setZero() - elif b == 1: - return - elif b == 2: - a.double() - elif b == 3: - var t {.noInit.}: typeof(a) - t.double(a) - a += t - elif b == 4: - a.double() - a.double() - elif b == 5: - var t {.noInit.}: typeof(a) - t.double(a) - t.double() - a += t - elif b == 6: - var t {.noInit.}: typeof(a) - t.double(a) - t += a # 3 - a.double(t) - elif b == 7: - var t {.noInit.}: typeof(a) - t.double(a) - t.double() - t.double() - a.diff(t, a) - elif b == 8: - a.double() - a.double() - a.double() - elif b == 9: - var t {.noInit.}: typeof(a) - t.double(a) - t.double() - t.double() - a.sum(t, a) - elif b == 10: - var t {.noInit.}: typeof(a) - t.double(a) - t.double() - a += t # 5 - a.double() - elif b == 11: - var t {.noInit.}: typeof(a) - t.double(a) - t += a # 3 - t.double() # 6 - t.double() # 12 - a.diff(t, a) # 11 - elif b == 12: - var t {.noInit.}: typeof(a) - t.double(a) - t += a # 3 - t.double() # 6 - a.double(t) # 12 - else: - {.error: "Multiplication by this small int not implemented".} + for i in 0 ..< a.coords.len: + a.coords[i] *= b func prod*(r: var ExtensionField, a: ExtensionField, b: static int) = ## Multiplication by a small integer known at compile-time - const negate = b < 0 - const b = if negate: -b - else: b - when negate: - r.neg(a) - else: - r = a + r = a r *= b {.pop.} # inline @@ -331,13 +261,21 @@ template doublePrec(T: type ExtensionField): type = when T.F is QuadraticExt: # Fp6Dbl CubicExt2x[QuadraticExt2x[doublePrec(T.F.F)]] +func has1extraBit(F: type Fp): bool = + ## We construct extensions only on Fp (and not Fr) + getSpareBits(F) >= 1 + +func has2extraBits(F: type Fp): bool = + ## We construct extensions only on Fp (and not Fr) + getSpareBits(F) >= 2 + func has1extraBit(E: type ExtensionField): bool = ## We construct extensions only on Fp (and not Fr) - getSpareBits(Fp[E.C]) >= 1 + getSpareBits(Fp[E.F.C]) >= 1 func has2extraBits(E: type ExtensionField): bool = ## We construct extensions only on Fp (and not Fr) - getSpareBits(Fp[E.C]) >= 2 + getSpareBits(Fp[E.F.C]) >= 2 template C(E: type ExtensionField2x): Curve = E.F.C @@ -392,6 +330,11 @@ func sum2xMod(r: var ExtensionField2x, a, b: ExtensionField2x) = staticFor i, 0, a.coords.len: r.coords[i].sum2xMod(a.coords[i], b.coords[i]) +func neg2xMod(r: var ExtensionField2x, a: ExtensionField2x) = + ## Double-precision modular negation + staticFor i, 0, a.coords.len: + r.coords[i].neg2xMod(a.coords[i], b.coords[i]) + # Reductions # ------------------------------------------------------------------- @@ -403,71 +346,10 @@ 
func redc2x(r: var ExtensionField, a: ExtensionField2x) = # Multiplication by a small integer known at compile-time # ------------------------------------------------------------------- -func prod( - r {.noAlias.}: var ExtensionField2x, - a {.noAlias.}: ExtensionField2x, b: static int) = +func prod(r: var ExtensionField2x, a: ExtensionField2x, b: static int) = ## Multiplication by a small integer known at compile-time - ## Requires no aliasing - const negate = b < 0 - const b = if negate: -b - else: b - - when negate: - r.diff2xMod(typeof(a)(), a) - when b == 0: - r.setZero() - elif b == 1: - r = a - elif b == 2: - r.sum2xMod(a, a) - elif b == 3: - r.sum2xMod(a, a) - r.sum2xMod(a, r) - elif b == 4: - r.sum2xMod(a, a) - r.sum2xMod(r, r) - elif b == 5: - r.sum2xMod(a, a) - r.sum2xMod(r, r) - r.sum2xMod(r, a) - elif b == 6: - r.sum2xMod(a, a) - let t2 = r - r.sum2xMod(r, r) # 4 - r.sum2xMod(t, t2) - elif b == 7: - r.sum2xMod(a, a) - r.sum2xMod(r, r) # 4 - r.sum2xMod(r, r) - r.diff2xMod(r, a) - elif b == 8: - r.sum2xMod(a, a) - r.sum2xMod(r, r) - r.sum2xMod(r, r) - elif b == 9: - r.sum2xMod(a, a) - r.sum2xMod(r, r) - r.sum2xMod(r, r) # 8 - r.sum2xMod(r, a) - elif b == 10: - r.sum2xMod(a, a) - r.sum2xMod(r, r) - r.sum2xMod(r, a) # 5 - r.sum2xMod(r, r) - elif b == 11: - r.sum2xMod(a, a) - r.sum2xMod(r, r) - r.sum2xMod(r, a) # 5 - r.sum2xMod(r, r) - r.sum2xMod(r, a) - elif b == 12: - r.sum2xMod(a, a) - r.sum2xMod(r, r) # 4 - let t4 = a - r.sum2xMod(r, r) # 8 - r.sum2xMod(r, t4) - else: - {.error: "Multiplication by this small int not implemented".} + for i in 0 ..< a.coords.len: + a *= b # NonResidue # ---------------------------------------------------------------------- @@ -493,22 +375,35 @@ func prod2x[C: static Curve]( r.c0.diff2xMod(a.c0, a.c1) r.c1.sum2xMod(a.c0, a.c1) else: - # ξ = u + v x - # and x² = β - # - # (c0 + c1 x) (u + v x) => u c0 + (u c0 + u c1)x + v c1 x² - # => u c0 + β v c1 + (v c0 + u c1) x - var t {.noInit.}: FpDbl[C] - - r.c0.prod2x(a.c0, U) - when V == 1 and Beta == -1: # Case BN254_Snarks - r.c0.diff2xMod(r.c0, a.c1) # r0 = u c0 + β v c1 + # Case: + # - BN254_Snarks, QNR_Fp: -1, SNR_Fp2: 9+1𝑖 (𝑖 = √-1) + # - BLS12_377, QNR_Fp: -5, SNR_Fp2: 0+1j (j = √-5) + # - BW6_761, SNR_Fp: -4, CNR_Fp2: 0+1j (j = √-4) + when U == 0: + # mul_sparse_by_0v + # r0 = β a1 v + # r1 = a0 v + # r and a don't alias, we use `r` as a temp location + r.c1.prod2x(a.c1, V) + r.c0.prod2x(r.c1, NonResidue) + r.c1.prod2x(a.c0, V) else: - {.error: "Unimplemented".} + # ξ = u + v x + # and x² = β + # + # (c0 + c1 x) (u + v x) => u c0 + (u c0 + u c1)x + v c1 x² + # => u c0 + β v c1 + (v c0 + u c1) x + var t {.noInit.}: FpDbl[C] + + r.c0.prod2x(a.c0, U) + when V == 1 and Beta == -1: # Case BN254_Snarks + r.c0.diff2xMod(r.c0, a.c1) # r0 = u c0 + β v c1 + else: + {.error: "Unimplemented".} - r.c1.prod2x(a.c0, V) - t.prod2x(a.c1, U) - r.c1.sum2xMod(r.c1, t) # r1 = v c0 + u c1 + r.c1.prod2x(a.c0, V) + t.prod2x(a.c1, U) + r.c1.sum2xMod(r.c1, t) # r1 = v c0 + u c1 # ############################################################ # # @@ -516,6 +411,12 @@ func prod2x[C: static Curve]( # # # ############################################################ +# Forward declarations +# ---------------------------------------------------------------------- + +func prod2x(r: var QuadraticExt2x, a, b: QuadraticExt) +func square2x(r: var QuadraticExt2x, a: QuadraticExt) + # Commutative ring implementation for complex quadratic extension fields # ---------------------------------------------------------------------- @@ -566,24 
+467,6 @@ func square2x_complex(r: var QuadraticExt2x, a: QuadraticExt) = t0.diff(a.c0, a.c1) r.c0.prod2x(t0, t1) # r0 = (a0 + a1)(a0 - a1) -func prod2x(r: var QuadraticExt2x, a, b: QuadraticExt) = - mixin fromComplexExtension - when a.fromComplexExtension(): - r.prod2x_complex(a, b) - else: - {.error: "Not Implementated".} - -func square2x_disjoint[Fdbl, F]( - r: var QuadraticExt2x[FDbl], - a0, a1: F) - -func square2x(r: var QuadraticExt2x, a: QuadraticExt) = - mixin fromComplexExtension, square2x_disjoint - when a.fromComplexExtension(): - r.square2x_complex(a) - else: - r.square2x_disjoint(a.c0, a.c1) - # Commutative ring implementation for generic quadratic extension fields # ---------------------------------------------------------------------- # @@ -609,22 +492,17 @@ func prod2x_disjoint[Fdbl, F]( var t0 {.noInit.}, t1 {.noInit.}: typeof(a.c0) # Single-width # Require 2 extra bits - V0.prod2x(a.c0, b0) - V1.prod2x(a.c1, b1) - when F.has2extraBits(): - t0.sumUnr(b0, b1) - t1.sumUnr(a.c0, a.c1) - else: - t0.sum(b0, b1) - t1.sum(a.c0, a.c1) + V0.prod2x(a.c0, b0) # v0 = a0b0 + V1.prod2x(a.c1, b1) # v1 = a1b1 + t0.sum(a.c0, a.c1) + t1.sum(b0, b1) - r.c1.prod2x(t0, t1) # r1 = (a0 + a1)(b0 + b1) , and at most 2 extra bits - r.c1.diff2xMod(r.c1, V0) # r1 = (a0 + a1)(b0 + b1) - a0b0 , and at most 1 extra bit - r.c1.diff2xMod(r.c1, V1) # r1 = (a0 + a1)(b0 + b1) - a0b0 - a1b1, and 0 extra bit + r.c1.prod2x(t0, t1) # r1 = (a0 + a1)(b0 + b1) + r.c1.diff2xMod(r.c1, V0) # r1 = (a0 + a1)(b0 + b1) - a0b0 + r.c1.diff2xMod(r.c1, V1) # r1 = (a0 + a1)(b0 + b1) - a0b0 - a1b1 - # TODO: This is correct modulo p, but we have some extra bits still here. - r.c0.prod2x(V1, NonResidue) # r0 = β a1 b1 - r.c0.sum2xMod(r.c0, V0) # r0 = a0 b0 + β a1 b1 + r.c0.prod2x(V1, NonResidue) # r0 = β a1 b1 + r.c0.sum2xMod(r.c0, V0) # r0 = a0 b0 + β a1 b1 func square2x_disjoint[Fdbl, F]( r: var QuadraticExt2x[FDbl], @@ -649,6 +527,23 @@ func square2x_disjoint[Fdbl, F]( r.c1.diff2xMod(r.c1, V0) r.c1.diff2xMod(r.c1, V1) +# Dispatch +# ---------------------------------------------------------------------- + +func prod2x(r: var QuadraticExt2x, a, b: QuadraticExt) = + mixin fromComplexExtension + when a.fromComplexExtension(): + r.prod2x_complex(a, b) + else: + r.prod2x_disjoint(a, b.c0, b.c1) + +func square2x(r: var QuadraticExt2x, a: QuadraticExt) = + mixin fromComplexExtension + when a.fromComplexExtension(): + r.square2x_complex(a) + else: + r.square2x_disjoint(a.c0, a.c1) + # ############################################################ # # # Cubic extensions - Lazy Reductions # @@ -1041,9 +936,10 @@ func prod*(r: var QuadraticExt, a, b: QuadraticExt) = r.c0.redc2x(d.c0) r.c1.redc2x(d.c1) else: - when true: # typeof(r.c0) is Fp: + when r.typeof.F.C == BW6_761 or typeof(r.c0) is Fp: + # BW6-761 requires too many registers for Dbl width path r.prod_generic(a, b) - else: # Deactivated r.c0 is correct modulo p but >= p + else: var d {.noInit.}: doublePrec(typeof(r)) d.prod2x_disjoint(a, b.c0, b.c1) r.c0.redc2x(d.c0) From 0fe1f2b4fbb48da9e3173d2452231810a43fff2a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Mamy=20Andr=C3=A9-Ratsimbazafy?= Date: Mon, 8 Feb 2021 23:28:46 +0100 Subject: [PATCH 17/22] remove mixin annotations --- constantine/tower_field_extensions/extension_fields.nim | 9 ++------- 1 file changed, 2 insertions(+), 7 deletions(-) diff --git a/constantine/tower_field_extensions/extension_fields.nim b/constantine/tower_field_extensions/extension_fields.nim index 947231533..ce27f982b 100644 --- 
a/constantine/tower_field_extensions/extension_fields.nim +++ b/constantine/tower_field_extensions/extension_fields.nim @@ -346,10 +346,10 @@ func redc2x(r: var ExtensionField, a: ExtensionField2x) = # Multiplication by a small integer known at compile-time # ------------------------------------------------------------------- -func prod(r: var ExtensionField2x, a: ExtensionField2x, b: static int) = +func prod2x(r: var ExtensionField2x, a: ExtensionField2x, b: static int) = ## Multiplication by a small integer known at compile-time for i in 0 ..< a.coords.len: - a *= b + r.coords[i].prod2x(a.coords[i], b) # NonResidue # ---------------------------------------------------------------------- @@ -739,7 +739,6 @@ func square_generic(r: var QuadraticExt, a: QuadraticExt) = r.c1 -= v1 else: - mixin prod var v0 {.noInit.}, v1 {.noInit.}: typeof(r.c0) # v1 <- (c0 + β c1) @@ -771,7 +770,6 @@ func prod_generic(r: var QuadraticExt, a, b: QuadraticExt) = # # r0 = a0 b0 + β a1 b1 # r1 = (a0 + a1) (b0 + b1) - a0 b0 - a1 b1 (Karatsuba) - mixin prod var v0 {.noInit.}, v1 {.noInit.}, v2 {.noInit.}: typeof(r.c0) # v2 <- (a0 + a1)(b0 + b1) @@ -806,7 +804,6 @@ func mul_sparse_generic_by_x0(r: var QuadraticExt, a, sparseB: QuadraticExt) = # # r0 = a0 b0 # r1 = (a0 + a1) b0 - a0 b0 = a1 b0 - mixin prod template b(): untyped = sparseB r.c0.prod(a.c0, b.c0) @@ -1038,7 +1035,6 @@ func mul_sparse_by_x0*(a: var QuadraticExt, sparseB: QuadraticExt) = func square_Chung_Hasan_SQR2(r: var CubicExt, a: CubicExt) {.used.}= ## Returns r = a² - mixin prod, square, sum var s0{.noInit.}, m01{.noInit.}, m12{.noInit.}: typeof(r.c0) # precomputations that use a @@ -1074,7 +1070,6 @@ func square_Chung_Hasan_SQR2(r: var CubicExt, a: CubicExt) {.used.}= func square_Chung_Hasan_SQR3(r: var CubicExt, a: CubicExt) = ## Returns r = a² - mixin prod, square, sum var s0{.noInit.}, t{.noInit.}, m12{.noInit.}: typeof(r.c0) # s₀ = (a₀ + a₁ + a₂)² From 234147bf0f184cac79c219e902ae413ca9d889c6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Mamy=20Andr=C3=A9-Ratsimbazafy?= Date: Mon, 8 Feb 2021 23:40:49 +0100 Subject: [PATCH 18/22] Lazy reduction in Fp4 prod --- .../tower_field_extensions/extension_fields.nim | 16 ++++++++++++---- 1 file changed, 12 insertions(+), 4 deletions(-) diff --git a/constantine/tower_field_extensions/extension_fields.nim b/constantine/tower_field_extensions/extension_fields.nim index ce27f982b..1d7dafcf0 100644 --- a/constantine/tower_field_extensions/extension_fields.nim +++ b/constantine/tower_field_extensions/extension_fields.nim @@ -494,12 +494,20 @@ func prod2x_disjoint[Fdbl, F]( # Require 2 extra bits V0.prod2x(a.c0, b0) # v0 = a0b0 V1.prod2x(a.c1, b1) # v1 = a1b1 - t0.sum(a.c0, a.c1) - t1.sum(b0, b1) + when F.has1extraBit(): + t0.sumUnr(a.c0, a.c1) + t1.sumUnr(b0, b1) + else: + t0.sum(a.c0, a.c1) + t1.sum(b0, b1) r.c1.prod2x(t0, t1) # r1 = (a0 + a1)(b0 + b1) - r.c1.diff2xMod(r.c1, V0) # r1 = (a0 + a1)(b0 + b1) - a0b0 - r.c1.diff2xMod(r.c1, V1) # r1 = (a0 + a1)(b0 + b1) - a0b0 - a1b1 + when F.has1extraBit(): + r.c1.diff2xMod(r.c1, V0) + r.c1.diff2xMod(r.c1, V1) + else: + r.c1.diff2xMod(r.c1, V0) # r1 = (a0 + a1)(b0 + b1) - a0b0 + r.c1.diff2xMod(r.c1, V1) # r1 = (a0 + a1)(b0 + b1) - a0b0 - a1b1 r.c0.prod2x(V1, NonResidue) # r0 = β a1 b1 r.c0.sum2xMod(r.c0, V0) # r0 = a0 b0 + β a1 b1 From 0df6bb0431b10ec6e9942f3a9f046b9c61d60825 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Mamy=20Andr=C3=A9-Ratsimbazafy?= Date: Tue, 9 Feb 2021 18:38:33 +0100 Subject: [PATCH 19/22] Fix assembly for sum2xMod --- 
.../assembly/limbs_asm_modular_dbl_prec_x86.nim | 12 ++++-------- .../arithmetic/assembly/limbs_asm_modular_x86.nim | 2 +- .../arithmetic/finite_fields_double_precision.nim | 2 +- docs/optimizations.md | 8 ++++---- 4 files changed, 10 insertions(+), 14 deletions(-) diff --git a/constantine/arithmetic/assembly/limbs_asm_modular_dbl_prec_x86.nim b/constantine/arithmetic/assembly/limbs_asm_modular_dbl_prec_x86.nim index 322e219b2..242c2454a 100644 --- a/constantine/arithmetic/assembly/limbs_asm_modular_dbl_prec_x86.nim +++ b/constantine/arithmetic/assembly/limbs_asm_modular_dbl_prec_x86.nim @@ -67,12 +67,10 @@ macro addmod2x_gen[N: static int](R: var Limbs[N], A, B: Limbs[N], m: Limbs[N di ctx.add u[0], b[0] else: ctx.adc u[i], b[i] + ctx.mov r[i], u[i] - # Everything should be hot in cache now so movs are cheaper - # we can try using 2 per ADC # v = a[H.. Date: Tue, 9 Feb 2021 18:58:11 +0100 Subject: [PATCH 20/22] Assembly for double-precision negation --- benchmarks/bench_blueprint.nim | 1 - benchmarks/bench_fp_double_precision.nim | 22 +++++- .../limbs_asm_modular_dbl_prec_x86.nim | 77 ++++++++++++++++++- .../arithmetic/assembly/limbs_asm_x86.nim | 4 +- .../finite_fields_double_precision.nim | 4 +- tests/t_finite_fields_double_precision.nim | 7 ++ 6 files changed, 104 insertions(+), 11 deletions(-) diff --git a/benchmarks/bench_blueprint.nim b/benchmarks/bench_blueprint.nim index ce3c10e68..eabdd8b3b 100644 --- a/benchmarks/bench_blueprint.nim +++ b/benchmarks/bench_blueprint.nim @@ -88,7 +88,6 @@ proc notes*() = echo " Bench on specific compiler with assembler: \"nimble bench_ec_g1_gcc\" or \"nimble bench_ec_g1_clang\"." echo " Bench on specific compiler with assembler: \"nimble bench_ec_g1_gcc_noasm\" or \"nimble bench_ec_g1_clang_noasm\"." echo " - The simplest operations might be optimized away by the compiler." - echo " - Fast Squaring and Fast Multiplication are possible if there are spare bits in the prime representation (i.e. the prime uses 254 bits out of 256 bits)" template measure*(iters: int, startTime, stopTime: untyped, diff --git a/benchmarks/bench_fp_double_precision.nim b/benchmarks/bench_fp_double_precision.nim index db4a4f25d..ba05464df 100644 --- a/benchmarks/bench_fp_double_precision.nim +++ b/benchmarks/bench_fp_double_precision.nim @@ -89,11 +89,13 @@ proc notes*() = echo "Notes:" echo " - Compilers:" echo " Compilers are severely limited on multiprecision arithmetic." - echo " Inline Assembly is used by default (nimble bench_fp)." - echo " Bench without assembly can use \"nimble bench_fp_gcc\" or \"nimble bench_fp_clang\"." + echo " Constantine compile-time assembler is used by default (nimble bench_fp)." echo " GCC is significantly slower than Clang on multiprecision arithmetic due to catastrophic handling of carries." + echo " GCC also seems to have issues with large temporaries and register spilling." + echo " This is somewhat alleviated by Constantine compile-time assembler." + echo " Bench on specific compiler with assembler: \"nimble bench_ec_g1_gcc\" or \"nimble bench_ec_g1_clang\"." + echo " Bench on specific compiler with assembler: \"nimble bench_ec_g1_gcc_noasm\" or \"nimble bench_ec_g1_clang_noasm\"." echo " - The simplest operations might be optimized away by the compiler." - echo " - Fast Squaring and Fast Multiplication are possible if there are spare bits in the prime representation (i.e. 
the prime uses 254 bits out of 256 bits)" template bench(op: string, desc: string, iters: int, body: untyped): untyped = let start = getMonotime() @@ -149,6 +151,12 @@ proc diff(T: typedesc, iters: int) = bench("Substraction", $T, iters): r.diff(a, b) +proc neg(T: typedesc, iters: int) = + var r: T + let a = rng.random_unsafe(T) + bench("Negation", $T, iters): + r.neg(a) + proc sum2xUnreduce(T: typedesc, iters: int) = var r, a, b: doublePrec(T) rng.random_unsafe(r, T) @@ -181,6 +189,12 @@ proc diff2x(T: typedesc, iters: int) = bench("Substraction 2x reduced", $doublePrec(T), iters): r.diff2xMod(a, b) +proc neg2x(T: typedesc, iters: int) = + var r, a: doublePrec(T) + rng.random_unsafe(a, T) + bench("Negation 2x reduced", $doublePrec(T), iters): + r.neg2xMod(a) + proc prod2xBench*(rLen, aLen, bLen: static int, iters: int) = var r: BigInt[rLen] let a = rng.random_unsafe(BigInt[aLen]) @@ -208,11 +222,13 @@ proc main() = sumUnr(Fp[BLS12_381], iters = 10_000_000) diff(Fp[BLS12_381], iters = 10_000_000) diffUnr(Fp[BLS12_381], iters = 10_000_000) + neg(Fp[BLS12_381], iters = 10_000_000) separator() sum2x(Fp[BLS12_381], iters = 10_000_000) sum2xUnreduce(Fp[BLS12_381], iters = 10_000_000) diff2x(Fp[BLS12_381], iters = 10_000_000) diff2xUnreduce(Fp[BLS12_381], iters = 10_000_000) + neg2x(Fp[BLS12_381], iters = 10_000_000) separator() prod2xBench(768, 384, 384, iters = 10_000_000) square2xBench(768, 384, iters = 10_000_000) diff --git a/constantine/arithmetic/assembly/limbs_asm_modular_dbl_prec_x86.nim b/constantine/arithmetic/assembly/limbs_asm_modular_dbl_prec_x86.nim index 242c2454a..c19d4e94b 100644 --- a/constantine/arithmetic/assembly/limbs_asm_modular_dbl_prec_x86.nim +++ b/constantine/arithmetic/assembly/limbs_asm_modular_dbl_prec_x86.nim @@ -29,10 +29,9 @@ import static: doAssert UseASM_X86_64 {.localPassC:"-fomit-frame-pointer".} # Needed so that the compiler finds enough registers -# Field addition +# Double-precision field addition # ------------------------------------------------------------ - macro addmod2x_gen[N: static int](R: var Limbs[N], A, B: Limbs[N], m: Limbs[N div 2]): untyped = ## Generate an optimized out-of-place double-precision addition kernel @@ -104,7 +103,7 @@ func addmod2x_asm*[N: static int](r: var Limbs[N], a, b: Limbs[N], M: Limbs[N di ## to stay in the [0, 2ⁿp) range addmod2x_gen(r, a, b, M) -# Field Substraction +# Double-precision field substraction # ------------------------------------------------------------ macro submod2x_gen[N: static int](R: var Limbs[N], A, B: Limbs[N], m: Limbs[N div 2]): untyped = @@ -171,3 +170,75 @@ func submod2x_asm*[N: static int](r: var Limbs[N], a, b: Limbs[N], M: Limbs[N di ## Output is conditionally reduced by 2ⁿp ## to stay in the [0, 2ⁿp) range submod2x_gen(r, a, b, M) + +# Double-precision field negation +# ------------------------------------------------------------ + +macro negmod2x_gen[N: static int](R: var Limbs[N], A: Limbs[N], m: Limbs[N div 2]): untyped = + ## Generate an optimized modular negation kernel + + result = newStmtList() + + var ctx = init(Assembler_x86, BaseType) + let + H = N div 2 + + a = init(OperandArray, nimSymbol = A, N, PointerInReg, Input) + r = init(OperandArray, nimSymbol = R, N, PointerInReg, InputOutput) + u = init(OperandArray, nimSymbol = ident"U", N, ElemsInReg, Output_EarlyClobber) + # We could force m as immediate by specializing per moduli + # We reuse the reg used for m for overflow detection + M = init(OperandArray, nimSymbol = m, N, PointerInReg, InputOutput) + + isZero = Operand( + 
desc: OperandDesc( + asmId: "[isZero]", + nimSymbol: ident"isZero", + rm: Reg, + constraint: Output_EarlyClobber, + cEmit: "isZero" + ) + ) + + # Substraction 2ⁿp - a + # The lower half of 2ⁿp is filled with zero + ctx.`xor` isZero, isZero + for i in 0 ..< H: + ctx.`xor` u[i], u[i] + ctx.`or` isZero, a[i] + + for i in 0 ..< H: + # 0 - a[i] + if i == 0: + ctx.sub u[0], a[0] + else: + ctx.sbb u[i], a[i] + # store result, overwrite a[i] lower-half if aliasing. + ctx.mov r[i], u[i] + # Prepare second-half, u <- M + ctx.mov u[i], M[i] + + for i in H ..< N: + # u = 2ⁿp higher half + ctx.sbb u[i-H], a[i] + + # Deal with a == 0, + # we already accumulated 0 in the first half (which was destroyed if aliasing) + for i in H ..< N: + ctx.`or` isZero, a[i] + + # Zero result if a == 0, only the upper half needs to be zero-ed here + for i in H ..< N: + ctx.cmovz u[i-H], isZero + ctx.mov r[i], u[i-H] + + let isZerosym = isZero.desc.nimSymbol + let usym = u.nimSymbol + result.add quote do: + var `isZerosym`{.noInit.}: BaseType + var `usym`{.noinit.}: typeof(`A`) + result.add ctx.generate + +func negmod2x_asm*[N: static int](r: var Limbs[N], a: Limbs[N], M: Limbs[N div 2]) = + ## Constant-time double-precision negation + negmod2x_gen(r, a, M) diff --git a/constantine/arithmetic/assembly/limbs_asm_x86.nim b/constantine/arithmetic/assembly/limbs_asm_x86.nim index 8321bcdfc..105bd2ecd 100644 --- a/constantine/arithmetic/assembly/limbs_asm_x86.nim +++ b/constantine/arithmetic/assembly/limbs_asm_x86.nim @@ -143,7 +143,7 @@ macro add_gen[N: static int](carry: var Carry, r: var Limbs[N], a, b: Limbs[N]): for i in 1 ..< N: ctx.mov t1, arrA[i] # Prepare the next iteration - ctx.mov arrR[i-1], t0 # Save the previous reult in an interleaved manner + ctx.mov arrR[i-1], t0 # Save the previous result in an interleaved manner ctx.adc t1, arrB[i] # Compute swap(t0, t1) # Break dependency chain @@ -210,7 +210,7 @@ macro sub_gen[N: static int](borrow: var Borrow, r: var Limbs[N], a, b: Limbs[N] ctx.mov arrR[N-1], t0 # Epilogue ctx.setToCarryFlag(borrow) - + # Codegen result.add ctx.generate diff --git a/constantine/arithmetic/finite_fields_double_precision.nim b/constantine/arithmetic/finite_fields_double_precision.nim index 5b3e425ce..f3eba513f 100644 --- a/constantine/arithmetic/finite_fields_double_precision.nim +++ b/constantine/arithmetic/finite_fields_double_precision.nim @@ -142,8 +142,8 @@ func sum2xMod*(r: var FpDbl, a, b: FpDbl) = func neg2xMod*(r: var FpDbl, a: FpDbl) = ## Double-precision modular substraction ## Negate modulo 2ⁿp - when false: # TODO UseASM_X86_64: - {.error: "Implement assembly".} + when UseASM_X86_64: + negmod2x_asm(r.limbs2x, a.limbs2x, FpDbl.C.Mod.limbs) else: # If a = 0 we need r = 0 and not r = M # as comparison operator assume unicity diff --git a/tests/t_finite_fields_double_precision.nim b/tests/t_finite_fields_double_precision.nim index e52f25156..60eaf5b58 100644 --- a/tests/t_finite_fields_double_precision.nim +++ b/tests/t_finite_fields_double_precision.nim @@ -154,6 +154,13 @@ suite "Field Addition/Substraction/Negation via double-precision field elements" for _ in 0 ..< Iters: addsubneg_random_long01Seq(BLS12_381) + test "Negate 0 returns 0 (unique Montgomery repr)": + var a: FpDbl[BN254_Snarks] + var r {.noInit.}: FpDbl[BN254_Snarks] + r.neg2xMod(a) + + check: bool r.isZero() + suite "Field Multiplication via double-precision field elements is consistent with single-width." 
& " [" & $WordBitwidth & "-bit mode]": test "With P-224 field modulus": for _ in 0 ..< Iters: From 427143c6c0e50954a6bbdff1dd17de95334caa39 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Mamy=20Andr=C3=A9-Ratsimbazafy?= Date: Tue, 9 Feb 2021 19:01:42 +0100 Subject: [PATCH 21/22] reduce white spaces in pairing benchmarks --- benchmarks/bench_pairing_template.nim | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/benchmarks/bench_pairing_template.nim b/benchmarks/bench_pairing_template.nim index c327c8bc2..7096ae1db 100644 --- a/benchmarks/bench_pairing_template.nim +++ b/benchmarks/bench_pairing_template.nim @@ -32,15 +32,15 @@ import ./bench_blueprint export notes -proc separator*() = separator(177) +proc separator*() = separator(132) proc report(op, curve: string, startTime, stopTime: MonoTime, startClk, stopClk: int64, iters: int) = let ns = inNanoseconds((stopTime-startTime) div iters) let throughput = 1e9 / float64(ns) when SupportsGetTicks: - echo &"{op:<60} {curve:<15} {throughput:>15.3f} ops/s {ns:>9} ns/op {(stopClk - startClk) div iters:>9} CPU cycles (approx)" + echo &"{op:<40} {curve:<15} {throughput:>15.3f} ops/s {ns:>9} ns/op {(stopClk - startClk) div iters:>9} CPU cycles (approx)" else: - echo &"{op:<60} {curve:<15} {throughput:>15.3f} ops/s {ns:>9} ns/op" + echo &"{op:<40} {curve:<15} {throughput:>15.3f} ops/s {ns:>9} ns/op" template bench(op: string, C: static Curve, iters: int, body: untyped): untyped = measure(iters, startTime, stopTime, startClk, stopClk, body) From 052f4bf3a73b6a659c8ef24f0ee46cf683e30cc6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Mamy=20Andr=C3=A9-Ratsimbazafy?= Date: Tue, 9 Feb 2021 19:02:51 +0100 Subject: [PATCH 22/22] ADX implies BMI2 --- constantine/arithmetic/limbs_extmul.nim | 3 ++- constantine/arithmetic/limbs_montgomery.nim | 6 ++++-- 2 files changed, 6 insertions(+), 3 deletions(-) diff --git a/constantine/arithmetic/limbs_extmul.nim b/constantine/arithmetic/limbs_extmul.nim index a674cb8bc..979de7a05 100644 --- a/constantine/arithmetic/limbs_extmul.nim +++ b/constantine/arithmetic/limbs_extmul.nim @@ -72,7 +72,8 @@ func prod*[rLen, aLen, bLen: static int](r: var Limbs[rLen], a: Limbs[aLen], b: ## `r` must not alias ``a`` or ``b`` when UseASM_X86_64 and aLen <= 6: - if ({.noSideEffect.}: hasBmi2()) and ({.noSideEffect.}: hasAdx()): + # ADX implies BMI2 + if ({.noSideEffect.}: hasAdx()): mul_asm_adx_bmi2(r, a, b) else: mul_asm(r, a, b) diff --git a/constantine/arithmetic/limbs_montgomery.nim b/constantine/arithmetic/limbs_montgomery.nim index a455c5141..ec90b46f3 100644 --- a/constantine/arithmetic/limbs_montgomery.nim +++ b/constantine/arithmetic/limbs_montgomery.nim @@ -421,7 +421,8 @@ func montyMul*( # - keep it generic and optimize code size when spareBits >= 1: when UseASM_X86_64 and a.len in {2 .. 6}: # TODO: handle spilling - if ({.noSideEffect.}: hasBmi2()) and ({.noSideEffect.}: hasAdx()): + # ADX implies BMI2 + if ({.noSideEffect.}: hasAdx()): montMul_CIOS_nocarry_asm_adx_bmi2(r, a, b, M, m0ninv) else: montMul_CIOS_nocarry_asm(r, a, b, M, m0ninv) @@ -466,7 +467,8 @@ func montyRedc2x*[N: static int]( m0ninv: BaseType, spareBits: static int) {.inline.} = ## Montgomery reduce a double-precision bigint modulo M when UseASM_X86_64 and r.len <= 6: - if ({.noSideEffect.}: hasBmi2()) and ({.noSideEffect.}: hasAdx()): + # ADX implies BMI2 + if ({.noSideEffect.}: hasAdx()): montRed_asm_adx_bmi2(r, a, M, m0ninv, spareBits) else: montRed_asm(r, a, M, m0ninv, spareBits)