From d7aa178d6f4b9f0d6967bb969fce12377a21f896 Mon Sep 17 00:00:00 2001 From: Gautam Botrel Date: Sat, 21 Sep 2024 14:39:08 +0000 Subject: [PATCH 01/14] feat: asm Vector sum slower, no avx --- ecc/bls12-377/fp/element_test.go | 17 +++++ ecc/bls12-377/fp/vector.go | 12 ++++ ecc/bls12-377/fr/element_ops_amd64.go | 12 ++++ ecc/bls12-377/fr/element_ops_amd64.s | 42 +++++++++++++ ecc/bls12-377/fr/element_ops_purego.go | 6 ++ ecc/bls12-377/fr/element_test.go | 17 +++++ ecc/bls12-377/fr/vector.go | 6 ++ ecc/bls12-381/fp/element_test.go | 17 +++++ ecc/bls12-381/fp/vector.go | 12 ++++ ecc/bls12-381/fr/element_ops_amd64.go | 12 ++++ ecc/bls12-381/fr/element_ops_amd64.s | 42 +++++++++++++ ecc/bls12-381/fr/element_ops_purego.go | 6 ++ ecc/bls12-381/fr/element_test.go | 17 +++++ ecc/bls12-381/fr/vector.go | 6 ++ ecc/bls24-315/fp/element_test.go | 17 +++++ ecc/bls24-315/fp/vector.go | 12 ++++ ecc/bls24-315/fr/element_ops_amd64.go | 12 ++++ ecc/bls24-315/fr/element_ops_amd64.s | 42 +++++++++++++ ecc/bls24-315/fr/element_ops_purego.go | 6 ++ ecc/bls24-315/fr/element_test.go | 17 +++++ ecc/bls24-315/fr/vector.go | 6 ++ ecc/bls24-317/fp/element_test.go | 17 +++++ ecc/bls24-317/fp/vector.go | 12 ++++ ecc/bls24-317/fr/element_ops_amd64.go | 12 ++++ ecc/bls24-317/fr/element_ops_amd64.s | 42 +++++++++++++ ecc/bls24-317/fr/element_ops_purego.go | 6 ++ ecc/bls24-317/fr/element_test.go | 17 +++++ ecc/bls24-317/fr/vector.go | 6 ++ ecc/bn254/fp/element_ops_amd64.go | 12 ++++ ecc/bn254/fp/element_ops_amd64.s | 42 +++++++++++++ ecc/bn254/fp/element_ops_purego.go | 6 ++ ecc/bn254/fp/element_test.go | 17 +++++ ecc/bn254/fp/vector.go | 6 ++ ecc/bn254/fr/element_ops_amd64.go | 12 ++++ ecc/bn254/fr/element_ops_amd64.s | 42 +++++++++++++ ecc/bn254/fr/element_ops_purego.go | 6 ++ ecc/bn254/fr/element_test.go | 17 +++++ ecc/bn254/fr/vector.go | 6 ++ ecc/bw6-633/fp/element_test.go | 17 +++++ ecc/bw6-633/fp/vector.go | 12 ++++ ecc/bw6-633/fr/element_test.go | 17 +++++ ecc/bw6-633/fr/vector.go | 12 ++++ ecc/bw6-761/fp/element_test.go | 17 +++++ ecc/bw6-761/fp/vector.go | 12 ++++ ecc/bw6-761/fr/element_test.go | 17 +++++ ecc/bw6-761/fr/vector.go | 12 ++++ ecc/secp256k1/fp/element_ops_purego.go | 6 ++ ecc/secp256k1/fp/element_test.go | 17 +++++ ecc/secp256k1/fp/vector.go | 6 ++ ecc/secp256k1/fr/element_ops_purego.go | 6 ++ ecc/secp256k1/fr/element_test.go | 17 +++++ ecc/secp256k1/fr/vector.go | 6 ++ ecc/stark-curve/fp/element_ops_amd64.go | 12 ++++ ecc/stark-curve/fp/element_ops_amd64.s | 42 +++++++++++++ ecc/stark-curve/fp/element_ops_purego.go | 6 ++ ecc/stark-curve/fp/element_test.go | 17 +++++ ecc/stark-curve/fp/vector.go | 6 ++ ecc/stark-curve/fr/element_ops_amd64.go | 12 ++++ ecc/stark-curve/fr/element_ops_amd64.s | 42 +++++++++++++ ecc/stark-curve/fr/element_ops_purego.go | 6 ++ ecc/stark-curve/fr/element_test.go | 17 +++++ ecc/stark-curve/fr/vector.go | 6 ++ field/generator/asm/amd64/build.go | 1 + field/generator/asm/amd64/element_vec.go | 62 +++++++++++++++++++ .../internal/templates/element/ops_asm.go | 12 ++++ .../internal/templates/element/ops_purego.go | 6 ++ .../internal/templates/element/tests.go | 17 +++++ .../internal/templates/element/vector.go | 12 ++++ field/goldilocks/element_test.go | 17 +++++ field/goldilocks/vector.go | 12 ++++ go.mod | 4 +- 71 files changed, 1096 insertions(+), 1 deletion(-) diff --git a/ecc/bls12-377/fp/element_test.go b/ecc/bls12-377/fp/element_test.go index 582d8b4af..72c8af2ae 100644 --- a/ecc/bls12-377/fp/element_test.go +++ b/ecc/bls12-377/fp/element_test.go @@ -746,6 +746,14 @@ func 
TestElementVecOps(t *testing.T) { expected.Mul(&a[i], &b[0]) assert.True(c[i].Equal(&expected), "Vector scaling failed") } + + // Vector sum + var sum Element + computed := c.Sum() + for i := 0; i < N; i++ { + sum.Add(&sum, &c[i]) + } + assert.True(sum.Equal(&computed), "Vector sum failed") } func BenchmarkElementVecOps(b *testing.B) { @@ -780,6 +788,15 @@ func BenchmarkElementVecOps(b *testing.B) { c1.ScalarMul(a1, &b1[0]) } }) + + b.Run("Sum", func(b *testing.B) { + b.ResetTimer() + var sum Element + for i := 0; i < b.N; i++ { + sum = c1.Sum() + } + _ = sum + }) } func TestElementAdd(t *testing.T) { diff --git a/ecc/bls12-377/fp/vector.go b/ecc/bls12-377/fp/vector.go index 0df05e337..94bd5e035 100644 --- a/ecc/bls12-377/fp/vector.go +++ b/ecc/bls12-377/fp/vector.go @@ -219,6 +219,12 @@ func (vector *Vector) ScalarMul(a Vector, b *Element) { scalarMulVecGeneric(*vector, a, b) } +// Sum computes the sum of all elements in the vector. +func (vector *Vector) Sum() (res Element) { + sumVecGeneric(&res, *vector) + return +} + func addVecGeneric(res, a, b Vector) { if len(a) != len(b) || len(a) != len(res) { panic("vector.Add: vectors don't have the same length") @@ -246,6 +252,12 @@ func scalarMulVecGeneric(res, a Vector, b *Element) { } } +func sumVecGeneric(res *Element, a Vector) { + for i := 0; i < len(a); i++ { + res.Add(res, &a[i]) + } +} + // TODO @gbotrel make a public package out of that. // execute executes the work function in parallel. // this is copy paste from internal/parallel/parallel.go diff --git a/ecc/bls12-377/fr/element_ops_amd64.go b/ecc/bls12-377/fr/element_ops_amd64.go index 21568255d..88c2d0c58 100644 --- a/ecc/bls12-377/fr/element_ops_amd64.go +++ b/ecc/bls12-377/fr/element_ops_amd64.go @@ -81,6 +81,18 @@ func (vector *Vector) ScalarMul(a Vector, b *Element) { //go:noescape func scalarMulVec(res, a, b *Element, n uint64) +// Sum computes the sum of all elements in the vector. 
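+// The result is fully reduced modulo q, and summing an empty vector returns
+// the zero element. A minimal usage sketch (x, y and z are placeholder
+// elements, not part of this change):
+//
+//	v := Vector{x, y, z}
+//	s := v.Sum() // s == x + y + z mod q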
+func (vector *Vector) Sum() (res Element) { + if len(*vector) == 0 { + return + } + sumVec(&res, &(*vector)[0], uint64(len(*vector))) + return +} + +//go:noescape +func sumVec(res *Element, a *Element, n uint64) + // Mul z = x * y (mod q) // // x and y must be less than q diff --git a/ecc/bls12-377/fr/element_ops_amd64.s b/ecc/bls12-377/fr/element_ops_amd64.s index ffa3b7bca..175325820 100644 --- a/ecc/bls12-377/fr/element_ops_amd64.s +++ b/ecc/bls12-377/fr/element_ops_amd64.s @@ -625,3 +625,45 @@ noAdx_5: MOVQ AX, 48(SP) CALL ·scalarMulVecGeneric(SB) RET + +// sumVec(res, a *Element, n uint64) res = sum(a[0...n]) +TEXT ·sumVec(SB), NOSPLIT, $0-24 + MOVQ a+8(FP), AX + MOVQ n+16(FP), DX + XORQ R8, R8 + XORQ R9, R9 + XORQ R10, R10 + XORQ R11, R11 + +loop_8: + TESTQ DX, DX + JEQ done_9 // n == 0, we are done + + // a[0] -> CX + // a[1] -> BX + // a[2] -> SI + // a[3] -> DI + MOVQ 0(AX), CX + MOVQ 8(AX), BX + MOVQ 16(AX), SI + MOVQ 24(AX), DI + ADDQ CX, R8 + ADCQ BX, R9 + ADCQ SI, R10 + ADCQ DI, R11 + + // reduce element(R8,R9,R10,R11) using temp registers (R12,R13,R14,R15) + REDUCE(R8,R9,R10,R11,R12,R13,R14,R15) + + // increment pointers to visit next element + ADDQ $32, AX + DECQ DX // decrement n + JMP loop_8 + +done_9: + MOVQ res+0(FP), AX + MOVQ R8, 0(AX) + MOVQ R9, 8(AX) + MOVQ R10, 16(AX) + MOVQ R11, 24(AX) + RET diff --git a/ecc/bls12-377/fr/element_ops_purego.go b/ecc/bls12-377/fr/element_ops_purego.go index 9c34ebecc..446bec2e2 100644 --- a/ecc/bls12-377/fr/element_ops_purego.go +++ b/ecc/bls12-377/fr/element_ops_purego.go @@ -78,6 +78,12 @@ func (vector *Vector) ScalarMul(a Vector, b *Element) { scalarMulVecGeneric(*vector, a, b) } +// Sum computes the sum of all elements in the vector. +func (vector *Vector) Sum() (res Element) { + sumVecGeneric(&res, *vector) + return +} + // Mul z = x * y (mod q) // // x and y must be less than q diff --git a/ecc/bls12-377/fr/element_test.go b/ecc/bls12-377/fr/element_test.go index 9b4190285..66b762147 100644 --- a/ecc/bls12-377/fr/element_test.go +++ b/ecc/bls12-377/fr/element_test.go @@ -742,6 +742,14 @@ func TestElementVecOps(t *testing.T) { expected.Mul(&a[i], &b[0]) assert.True(c[i].Equal(&expected), "Vector scaling failed") } + + // Vector sum + var sum Element + computed := c.Sum() + for i := 0; i < N; i++ { + sum.Add(&sum, &c[i]) + } + assert.True(sum.Equal(&computed), "Vector sum failed") } func BenchmarkElementVecOps(b *testing.B) { @@ -776,6 +784,15 @@ func BenchmarkElementVecOps(b *testing.B) { c1.ScalarMul(a1, &b1[0]) } }) + + b.Run("Sum", func(b *testing.B) { + b.ResetTimer() + var sum Element + for i := 0; i < b.N; i++ { + sum = c1.Sum() + } + _ = sum + }) } func TestElementAdd(t *testing.T) { diff --git a/ecc/bls12-377/fr/vector.go b/ecc/bls12-377/fr/vector.go index f39828547..d6a66b036 100644 --- a/ecc/bls12-377/fr/vector.go +++ b/ecc/bls12-377/fr/vector.go @@ -226,6 +226,12 @@ func scalarMulVecGeneric(res, a Vector, b *Element) { } } +func sumVecGeneric(res *Element, a Vector) { + for i := 0; i < len(a); i++ { + res.Add(res, &a[i]) + } +} + // TODO @gbotrel make a public package out of that. // execute executes the work function in parallel. 
// this is copy paste from internal/parallel/parallel.go diff --git a/ecc/bls12-381/fp/element_test.go b/ecc/bls12-381/fp/element_test.go index d070a1814..da7263c5a 100644 --- a/ecc/bls12-381/fp/element_test.go +++ b/ecc/bls12-381/fp/element_test.go @@ -746,6 +746,14 @@ func TestElementVecOps(t *testing.T) { expected.Mul(&a[i], &b[0]) assert.True(c[i].Equal(&expected), "Vector scaling failed") } + + // Vector sum + var sum Element + computed := c.Sum() + for i := 0; i < N; i++ { + sum.Add(&sum, &c[i]) + } + assert.True(sum.Equal(&computed), "Vector sum failed") } func BenchmarkElementVecOps(b *testing.B) { @@ -780,6 +788,15 @@ func BenchmarkElementVecOps(b *testing.B) { c1.ScalarMul(a1, &b1[0]) } }) + + b.Run("Sum", func(b *testing.B) { + b.ResetTimer() + var sum Element + for i := 0; i < b.N; i++ { + sum = c1.Sum() + } + _ = sum + }) } func TestElementAdd(t *testing.T) { diff --git a/ecc/bls12-381/fp/vector.go b/ecc/bls12-381/fp/vector.go index 0df05e337..94bd5e035 100644 --- a/ecc/bls12-381/fp/vector.go +++ b/ecc/bls12-381/fp/vector.go @@ -219,6 +219,12 @@ func (vector *Vector) ScalarMul(a Vector, b *Element) { scalarMulVecGeneric(*vector, a, b) } +// Sum computes the sum of all elements in the vector. +func (vector *Vector) Sum() (res Element) { + sumVecGeneric(&res, *vector) + return +} + func addVecGeneric(res, a, b Vector) { if len(a) != len(b) || len(a) != len(res) { panic("vector.Add: vectors don't have the same length") @@ -246,6 +252,12 @@ func scalarMulVecGeneric(res, a Vector, b *Element) { } } +func sumVecGeneric(res *Element, a Vector) { + for i := 0; i < len(a); i++ { + res.Add(res, &a[i]) + } +} + // TODO @gbotrel make a public package out of that. // execute executes the work function in parallel. // this is copy paste from internal/parallel/parallel.go diff --git a/ecc/bls12-381/fr/element_ops_amd64.go b/ecc/bls12-381/fr/element_ops_amd64.go index 21568255d..88c2d0c58 100644 --- a/ecc/bls12-381/fr/element_ops_amd64.go +++ b/ecc/bls12-381/fr/element_ops_amd64.go @@ -81,6 +81,18 @@ func (vector *Vector) ScalarMul(a Vector, b *Element) { //go:noescape func scalarMulVec(res, a, b *Element, n uint64) +// Sum computes the sum of all elements in the vector. 
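+// The length guard below is needed because sumVec is handed a pointer to the
+// first element plus the element count; indexing an empty vector would panic.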
+func (vector *Vector) Sum() (res Element) { + if len(*vector) == 0 { + return + } + sumVec(&res, &(*vector)[0], uint64(len(*vector))) + return +} + +//go:noescape +func sumVec(res *Element, a *Element, n uint64) + // Mul z = x * y (mod q) // // x and y must be less than q diff --git a/ecc/bls12-381/fr/element_ops_amd64.s b/ecc/bls12-381/fr/element_ops_amd64.s index caffb72b1..bb51bb446 100644 --- a/ecc/bls12-381/fr/element_ops_amd64.s +++ b/ecc/bls12-381/fr/element_ops_amd64.s @@ -625,3 +625,45 @@ noAdx_5: MOVQ AX, 48(SP) CALL ·scalarMulVecGeneric(SB) RET + +// sumVec(res, a *Element, n uint64) res = sum(a[0...n]) +TEXT ·sumVec(SB), NOSPLIT, $0-24 + MOVQ a+8(FP), AX + MOVQ n+16(FP), DX + XORQ R8, R8 + XORQ R9, R9 + XORQ R10, R10 + XORQ R11, R11 + +loop_8: + TESTQ DX, DX + JEQ done_9 // n == 0, we are done + + // a[0] -> CX + // a[1] -> BX + // a[2] -> SI + // a[3] -> DI + MOVQ 0(AX), CX + MOVQ 8(AX), BX + MOVQ 16(AX), SI + MOVQ 24(AX), DI + ADDQ CX, R8 + ADCQ BX, R9 + ADCQ SI, R10 + ADCQ DI, R11 + + // reduce element(R8,R9,R10,R11) using temp registers (R12,R13,R14,R15) + REDUCE(R8,R9,R10,R11,R12,R13,R14,R15) + + // increment pointers to visit next element + ADDQ $32, AX + DECQ DX // decrement n + JMP loop_8 + +done_9: + MOVQ res+0(FP), AX + MOVQ R8, 0(AX) + MOVQ R9, 8(AX) + MOVQ R10, 16(AX) + MOVQ R11, 24(AX) + RET diff --git a/ecc/bls12-381/fr/element_ops_purego.go b/ecc/bls12-381/fr/element_ops_purego.go index 50e839865..3c494940b 100644 --- a/ecc/bls12-381/fr/element_ops_purego.go +++ b/ecc/bls12-381/fr/element_ops_purego.go @@ -78,6 +78,12 @@ func (vector *Vector) ScalarMul(a Vector, b *Element) { scalarMulVecGeneric(*vector, a, b) } +// Sum computes the sum of all elements in the vector. +func (vector *Vector) Sum() (res Element) { + sumVecGeneric(&res, *vector) + return +} + // Mul z = x * y (mod q) // // x and y must be less than q diff --git a/ecc/bls12-381/fr/element_test.go b/ecc/bls12-381/fr/element_test.go index 684ea1525..a0ce1d6ea 100644 --- a/ecc/bls12-381/fr/element_test.go +++ b/ecc/bls12-381/fr/element_test.go @@ -742,6 +742,14 @@ func TestElementVecOps(t *testing.T) { expected.Mul(&a[i], &b[0]) assert.True(c[i].Equal(&expected), "Vector scaling failed") } + + // Vector sum + var sum Element + computed := c.Sum() + for i := 0; i < N; i++ { + sum.Add(&sum, &c[i]) + } + assert.True(sum.Equal(&computed), "Vector sum failed") } func BenchmarkElementVecOps(b *testing.B) { @@ -776,6 +784,15 @@ func BenchmarkElementVecOps(b *testing.B) { c1.ScalarMul(a1, &b1[0]) } }) + + b.Run("Sum", func(b *testing.B) { + b.ResetTimer() + var sum Element + for i := 0; i < b.N; i++ { + sum = c1.Sum() + } + _ = sum + }) } func TestElementAdd(t *testing.T) { diff --git a/ecc/bls12-381/fr/vector.go b/ecc/bls12-381/fr/vector.go index f39828547..d6a66b036 100644 --- a/ecc/bls12-381/fr/vector.go +++ b/ecc/bls12-381/fr/vector.go @@ -226,6 +226,12 @@ func scalarMulVecGeneric(res, a Vector, b *Element) { } } +func sumVecGeneric(res *Element, a Vector) { + for i := 0; i < len(a); i++ { + res.Add(res, &a[i]) + } +} + // TODO @gbotrel make a public package out of that. // execute executes the work function in parallel. 
// this is copy paste from internal/parallel/parallel.go diff --git a/ecc/bls24-315/fp/element_test.go b/ecc/bls24-315/fp/element_test.go index 665ffce6a..31f018b1b 100644 --- a/ecc/bls24-315/fp/element_test.go +++ b/ecc/bls24-315/fp/element_test.go @@ -744,6 +744,14 @@ func TestElementVecOps(t *testing.T) { expected.Mul(&a[i], &b[0]) assert.True(c[i].Equal(&expected), "Vector scaling failed") } + + // Vector sum + var sum Element + computed := c.Sum() + for i := 0; i < N; i++ { + sum.Add(&sum, &c[i]) + } + assert.True(sum.Equal(&computed), "Vector sum failed") } func BenchmarkElementVecOps(b *testing.B) { @@ -778,6 +786,15 @@ func BenchmarkElementVecOps(b *testing.B) { c1.ScalarMul(a1, &b1[0]) } }) + + b.Run("Sum", func(b *testing.B) { + b.ResetTimer() + var sum Element + for i := 0; i < b.N; i++ { + sum = c1.Sum() + } + _ = sum + }) } func TestElementAdd(t *testing.T) { diff --git a/ecc/bls24-315/fp/vector.go b/ecc/bls24-315/fp/vector.go index 01b326d49..7c9d8b2f3 100644 --- a/ecc/bls24-315/fp/vector.go +++ b/ecc/bls24-315/fp/vector.go @@ -218,6 +218,12 @@ func (vector *Vector) ScalarMul(a Vector, b *Element) { scalarMulVecGeneric(*vector, a, b) } +// Sum computes the sum of all elements in the vector. +func (vector *Vector) Sum() (res Element) { + sumVecGeneric(&res, *vector) + return +} + func addVecGeneric(res, a, b Vector) { if len(a) != len(b) || len(a) != len(res) { panic("vector.Add: vectors don't have the same length") @@ -245,6 +251,12 @@ func scalarMulVecGeneric(res, a Vector, b *Element) { } } +func sumVecGeneric(res *Element, a Vector) { + for i := 0; i < len(a); i++ { + res.Add(res, &a[i]) + } +} + // TODO @gbotrel make a public package out of that. // execute executes the work function in parallel. // this is copy paste from internal/parallel/parallel.go diff --git a/ecc/bls24-315/fr/element_ops_amd64.go b/ecc/bls24-315/fr/element_ops_amd64.go index 21568255d..88c2d0c58 100644 --- a/ecc/bls24-315/fr/element_ops_amd64.go +++ b/ecc/bls24-315/fr/element_ops_amd64.go @@ -81,6 +81,18 @@ func (vector *Vector) ScalarMul(a Vector, b *Element) { //go:noescape func scalarMulVec(res, a, b *Element, n uint64) +// Sum computes the sum of all elements in the vector. 
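+// sumVec is declared //go:noescape below and implemented in
+// element_ops_amd64.s for this curve.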
+func (vector *Vector) Sum() (res Element) { + if len(*vector) == 0 { + return + } + sumVec(&res, &(*vector)[0], uint64(len(*vector))) + return +} + +//go:noescape +func sumVec(res *Element, a *Element, n uint64) + // Mul z = x * y (mod q) // // x and y must be less than q diff --git a/ecc/bls24-315/fr/element_ops_amd64.s b/ecc/bls24-315/fr/element_ops_amd64.s index 2e52c653b..bee861d35 100644 --- a/ecc/bls24-315/fr/element_ops_amd64.s +++ b/ecc/bls24-315/fr/element_ops_amd64.s @@ -625,3 +625,45 @@ noAdx_5: MOVQ AX, 48(SP) CALL ·scalarMulVecGeneric(SB) RET + +// sumVec(res, a *Element, n uint64) res = sum(a[0...n]) +TEXT ·sumVec(SB), NOSPLIT, $0-24 + MOVQ a+8(FP), AX + MOVQ n+16(FP), DX + XORQ R8, R8 + XORQ R9, R9 + XORQ R10, R10 + XORQ R11, R11 + +loop_8: + TESTQ DX, DX + JEQ done_9 // n == 0, we are done + + // a[0] -> CX + // a[1] -> BX + // a[2] -> SI + // a[3] -> DI + MOVQ 0(AX), CX + MOVQ 8(AX), BX + MOVQ 16(AX), SI + MOVQ 24(AX), DI + ADDQ CX, R8 + ADCQ BX, R9 + ADCQ SI, R10 + ADCQ DI, R11 + + // reduce element(R8,R9,R10,R11) using temp registers (R12,R13,R14,R15) + REDUCE(R8,R9,R10,R11,R12,R13,R14,R15) + + // increment pointers to visit next element + ADDQ $32, AX + DECQ DX // decrement n + JMP loop_8 + +done_9: + MOVQ res+0(FP), AX + MOVQ R8, 0(AX) + MOVQ R9, 8(AX) + MOVQ R10, 16(AX) + MOVQ R11, 24(AX) + RET diff --git a/ecc/bls24-315/fr/element_ops_purego.go b/ecc/bls24-315/fr/element_ops_purego.go index 7b6cfd87b..2149e5d8b 100644 --- a/ecc/bls24-315/fr/element_ops_purego.go +++ b/ecc/bls24-315/fr/element_ops_purego.go @@ -78,6 +78,12 @@ func (vector *Vector) ScalarMul(a Vector, b *Element) { scalarMulVecGeneric(*vector, a, b) } +// Sum computes the sum of all elements in the vector. +func (vector *Vector) Sum() (res Element) { + sumVecGeneric(&res, *vector) + return +} + // Mul z = x * y (mod q) // // x and y must be less than q diff --git a/ecc/bls24-315/fr/element_test.go b/ecc/bls24-315/fr/element_test.go index ac030b6d0..8e3b19942 100644 --- a/ecc/bls24-315/fr/element_test.go +++ b/ecc/bls24-315/fr/element_test.go @@ -742,6 +742,14 @@ func TestElementVecOps(t *testing.T) { expected.Mul(&a[i], &b[0]) assert.True(c[i].Equal(&expected), "Vector scaling failed") } + + // Vector sum + var sum Element + computed := c.Sum() + for i := 0; i < N; i++ { + sum.Add(&sum, &c[i]) + } + assert.True(sum.Equal(&computed), "Vector sum failed") } func BenchmarkElementVecOps(b *testing.B) { @@ -776,6 +784,15 @@ func BenchmarkElementVecOps(b *testing.B) { c1.ScalarMul(a1, &b1[0]) } }) + + b.Run("Sum", func(b *testing.B) { + b.ResetTimer() + var sum Element + for i := 0; i < b.N; i++ { + sum = c1.Sum() + } + _ = sum + }) } func TestElementAdd(t *testing.T) { diff --git a/ecc/bls24-315/fr/vector.go b/ecc/bls24-315/fr/vector.go index f39828547..d6a66b036 100644 --- a/ecc/bls24-315/fr/vector.go +++ b/ecc/bls24-315/fr/vector.go @@ -226,6 +226,12 @@ func scalarMulVecGeneric(res, a Vector, b *Element) { } } +func sumVecGeneric(res *Element, a Vector) { + for i := 0; i < len(a); i++ { + res.Add(res, &a[i]) + } +} + // TODO @gbotrel make a public package out of that. // execute executes the work function in parallel. 
// this is copy paste from internal/parallel/parallel.go diff --git a/ecc/bls24-317/fp/element_test.go b/ecc/bls24-317/fp/element_test.go index 7bbabe259..76758acc7 100644 --- a/ecc/bls24-317/fp/element_test.go +++ b/ecc/bls24-317/fp/element_test.go @@ -744,6 +744,14 @@ func TestElementVecOps(t *testing.T) { expected.Mul(&a[i], &b[0]) assert.True(c[i].Equal(&expected), "Vector scaling failed") } + + // Vector sum + var sum Element + computed := c.Sum() + for i := 0; i < N; i++ { + sum.Add(&sum, &c[i]) + } + assert.True(sum.Equal(&computed), "Vector sum failed") } func BenchmarkElementVecOps(b *testing.B) { @@ -778,6 +786,15 @@ func BenchmarkElementVecOps(b *testing.B) { c1.ScalarMul(a1, &b1[0]) } }) + + b.Run("Sum", func(b *testing.B) { + b.ResetTimer() + var sum Element + for i := 0; i < b.N; i++ { + sum = c1.Sum() + } + _ = sum + }) } func TestElementAdd(t *testing.T) { diff --git a/ecc/bls24-317/fp/vector.go b/ecc/bls24-317/fp/vector.go index 01b326d49..7c9d8b2f3 100644 --- a/ecc/bls24-317/fp/vector.go +++ b/ecc/bls24-317/fp/vector.go @@ -218,6 +218,12 @@ func (vector *Vector) ScalarMul(a Vector, b *Element) { scalarMulVecGeneric(*vector, a, b) } +// Sum computes the sum of all elements in the vector. +func (vector *Vector) Sum() (res Element) { + sumVecGeneric(&res, *vector) + return +} + func addVecGeneric(res, a, b Vector) { if len(a) != len(b) || len(a) != len(res) { panic("vector.Add: vectors don't have the same length") @@ -245,6 +251,12 @@ func scalarMulVecGeneric(res, a Vector, b *Element) { } } +func sumVecGeneric(res *Element, a Vector) { + for i := 0; i < len(a); i++ { + res.Add(res, &a[i]) + } +} + // TODO @gbotrel make a public package out of that. // execute executes the work function in parallel. // this is copy paste from internal/parallel/parallel.go diff --git a/ecc/bls24-317/fr/element_ops_amd64.go b/ecc/bls24-317/fr/element_ops_amd64.go index 21568255d..88c2d0c58 100644 --- a/ecc/bls24-317/fr/element_ops_amd64.go +++ b/ecc/bls24-317/fr/element_ops_amd64.go @@ -81,6 +81,18 @@ func (vector *Vector) ScalarMul(a Vector, b *Element) { //go:noescape func scalarMulVec(res, a, b *Element, n uint64) +// Sum computes the sum of all elements in the vector. 
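+// The assembly routine reduces the running total after every element, so the
+// four 64-bit accumulator words never overflow.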
+func (vector *Vector) Sum() (res Element) { + if len(*vector) == 0 { + return + } + sumVec(&res, &(*vector)[0], uint64(len(*vector))) + return +} + +//go:noescape +func sumVec(res *Element, a *Element, n uint64) + // Mul z = x * y (mod q) // // x and y must be less than q diff --git a/ecc/bls24-317/fr/element_ops_amd64.s b/ecc/bls24-317/fr/element_ops_amd64.s index fd237dad9..cb475094a 100644 --- a/ecc/bls24-317/fr/element_ops_amd64.s +++ b/ecc/bls24-317/fr/element_ops_amd64.s @@ -625,3 +625,45 @@ noAdx_5: MOVQ AX, 48(SP) CALL ·scalarMulVecGeneric(SB) RET + +// sumVec(res, a *Element, n uint64) res = sum(a[0...n]) +TEXT ·sumVec(SB), NOSPLIT, $0-24 + MOVQ a+8(FP), AX + MOVQ n+16(FP), DX + XORQ R8, R8 + XORQ R9, R9 + XORQ R10, R10 + XORQ R11, R11 + +loop_8: + TESTQ DX, DX + JEQ done_9 // n == 0, we are done + + // a[0] -> CX + // a[1] -> BX + // a[2] -> SI + // a[3] -> DI + MOVQ 0(AX), CX + MOVQ 8(AX), BX + MOVQ 16(AX), SI + MOVQ 24(AX), DI + ADDQ CX, R8 + ADCQ BX, R9 + ADCQ SI, R10 + ADCQ DI, R11 + + // reduce element(R8,R9,R10,R11) using temp registers (R12,R13,R14,R15) + REDUCE(R8,R9,R10,R11,R12,R13,R14,R15) + + // increment pointers to visit next element + ADDQ $32, AX + DECQ DX // decrement n + JMP loop_8 + +done_9: + MOVQ res+0(FP), AX + MOVQ R8, 0(AX) + MOVQ R9, 8(AX) + MOVQ R10, 16(AX) + MOVQ R11, 24(AX) + RET diff --git a/ecc/bls24-317/fr/element_ops_purego.go b/ecc/bls24-317/fr/element_ops_purego.go index 14505483c..f18cc97c5 100644 --- a/ecc/bls24-317/fr/element_ops_purego.go +++ b/ecc/bls24-317/fr/element_ops_purego.go @@ -78,6 +78,12 @@ func (vector *Vector) ScalarMul(a Vector, b *Element) { scalarMulVecGeneric(*vector, a, b) } +// Sum computes the sum of all elements in the vector. +func (vector *Vector) Sum() (res Element) { + sumVecGeneric(&res, *vector) + return +} + // Mul z = x * y (mod q) // // x and y must be less than q diff --git a/ecc/bls24-317/fr/element_test.go b/ecc/bls24-317/fr/element_test.go index c533cc1c9..85c80207c 100644 --- a/ecc/bls24-317/fr/element_test.go +++ b/ecc/bls24-317/fr/element_test.go @@ -742,6 +742,14 @@ func TestElementVecOps(t *testing.T) { expected.Mul(&a[i], &b[0]) assert.True(c[i].Equal(&expected), "Vector scaling failed") } + + // Vector sum + var sum Element + computed := c.Sum() + for i := 0; i < N; i++ { + sum.Add(&sum, &c[i]) + } + assert.True(sum.Equal(&computed), "Vector sum failed") } func BenchmarkElementVecOps(b *testing.B) { @@ -776,6 +784,15 @@ func BenchmarkElementVecOps(b *testing.B) { c1.ScalarMul(a1, &b1[0]) } }) + + b.Run("Sum", func(b *testing.B) { + b.ResetTimer() + var sum Element + for i := 0; i < b.N; i++ { + sum = c1.Sum() + } + _ = sum + }) } func TestElementAdd(t *testing.T) { diff --git a/ecc/bls24-317/fr/vector.go b/ecc/bls24-317/fr/vector.go index f39828547..d6a66b036 100644 --- a/ecc/bls24-317/fr/vector.go +++ b/ecc/bls24-317/fr/vector.go @@ -226,6 +226,12 @@ func scalarMulVecGeneric(res, a Vector, b *Element) { } } +func sumVecGeneric(res *Element, a Vector) { + for i := 0; i < len(a); i++ { + res.Add(res, &a[i]) + } +} + // TODO @gbotrel make a public package out of that. // execute executes the work function in parallel. 
// this is copy paste from internal/parallel/parallel.go diff --git a/ecc/bn254/fp/element_ops_amd64.go b/ecc/bn254/fp/element_ops_amd64.go index 6f16baf68..b7e3b3542 100644 --- a/ecc/bn254/fp/element_ops_amd64.go +++ b/ecc/bn254/fp/element_ops_amd64.go @@ -81,6 +81,18 @@ func (vector *Vector) ScalarMul(a Vector, b *Element) { //go:noescape func scalarMulVec(res, a, b *Element, n uint64) +// Sum computes the sum of all elements in the vector. +func (vector *Vector) Sum() (res Element) { + if len(*vector) == 0 { + return + } + sumVec(&res, &(*vector)[0], uint64(len(*vector))) + return +} + +//go:noescape +func sumVec(res *Element, a *Element, n uint64) + // Mul z = x * y (mod q) // // x and y must be less than q diff --git a/ecc/bn254/fp/element_ops_amd64.s b/ecc/bn254/fp/element_ops_amd64.s index cbfba4ee5..86d45bce1 100644 --- a/ecc/bn254/fp/element_ops_amd64.s +++ b/ecc/bn254/fp/element_ops_amd64.s @@ -625,3 +625,45 @@ noAdx_5: MOVQ AX, 48(SP) CALL ·scalarMulVecGeneric(SB) RET + +// sumVec(res, a *Element, n uint64) res = sum(a[0...n]) +TEXT ·sumVec(SB), NOSPLIT, $0-24 + MOVQ a+8(FP), AX + MOVQ n+16(FP), DX + XORQ R8, R8 + XORQ R9, R9 + XORQ R10, R10 + XORQ R11, R11 + +loop_8: + TESTQ DX, DX + JEQ done_9 // n == 0, we are done + + // a[0] -> CX + // a[1] -> BX + // a[2] -> SI + // a[3] -> DI + MOVQ 0(AX), CX + MOVQ 8(AX), BX + MOVQ 16(AX), SI + MOVQ 24(AX), DI + ADDQ CX, R8 + ADCQ BX, R9 + ADCQ SI, R10 + ADCQ DI, R11 + + // reduce element(R8,R9,R10,R11) using temp registers (R12,R13,R14,R15) + REDUCE(R8,R9,R10,R11,R12,R13,R14,R15) + + // increment pointers to visit next element + ADDQ $32, AX + DECQ DX // decrement n + JMP loop_8 + +done_9: + MOVQ res+0(FP), AX + MOVQ R8, 0(AX) + MOVQ R9, 8(AX) + MOVQ R10, 16(AX) + MOVQ R11, 24(AX) + RET diff --git a/ecc/bn254/fp/element_ops_purego.go b/ecc/bn254/fp/element_ops_purego.go index 250ac5bce..8724acf1b 100644 --- a/ecc/bn254/fp/element_ops_purego.go +++ b/ecc/bn254/fp/element_ops_purego.go @@ -78,6 +78,12 @@ func (vector *Vector) ScalarMul(a Vector, b *Element) { scalarMulVecGeneric(*vector, a, b) } +// Sum computes the sum of all elements in the vector. +func (vector *Vector) Sum() (res Element) { + sumVecGeneric(&res, *vector) + return +} + // Mul z = x * y (mod q) // // x and y must be less than q diff --git a/ecc/bn254/fp/element_test.go b/ecc/bn254/fp/element_test.go index a923ef657..2ae7618d2 100644 --- a/ecc/bn254/fp/element_test.go +++ b/ecc/bn254/fp/element_test.go @@ -742,6 +742,14 @@ func TestElementVecOps(t *testing.T) { expected.Mul(&a[i], &b[0]) assert.True(c[i].Equal(&expected), "Vector scaling failed") } + + // Vector sum + var sum Element + computed := c.Sum() + for i := 0; i < N; i++ { + sum.Add(&sum, &c[i]) + } + assert.True(sum.Equal(&computed), "Vector sum failed") } func BenchmarkElementVecOps(b *testing.B) { @@ -776,6 +784,15 @@ func BenchmarkElementVecOps(b *testing.B) { c1.ScalarMul(a1, &b1[0]) } }) + + b.Run("Sum", func(b *testing.B) { + b.ResetTimer() + var sum Element + for i := 0; i < b.N; i++ { + sum = c1.Sum() + } + _ = sum + }) } func TestElementAdd(t *testing.T) { diff --git a/ecc/bn254/fp/vector.go b/ecc/bn254/fp/vector.go index 850b3603d..1c29dcbf2 100644 --- a/ecc/bn254/fp/vector.go +++ b/ecc/bn254/fp/vector.go @@ -226,6 +226,12 @@ func scalarMulVecGeneric(res, a Vector, b *Element) { } } +func sumVecGeneric(res *Element, a Vector) { + for i := 0; i < len(a); i++ { + res.Add(res, &a[i]) + } +} + // TODO @gbotrel make a public package out of that. // execute executes the work function in parallel. 
// this is copy paste from internal/parallel/parallel.go diff --git a/ecc/bn254/fr/element_ops_amd64.go b/ecc/bn254/fr/element_ops_amd64.go index 21568255d..88c2d0c58 100644 --- a/ecc/bn254/fr/element_ops_amd64.go +++ b/ecc/bn254/fr/element_ops_amd64.go @@ -81,6 +81,18 @@ func (vector *Vector) ScalarMul(a Vector, b *Element) { //go:noescape func scalarMulVec(res, a, b *Element, n uint64) +// Sum computes the sum of all elements in the vector. +func (vector *Vector) Sum() (res Element) { + if len(*vector) == 0 { + return + } + sumVec(&res, &(*vector)[0], uint64(len(*vector))) + return +} + +//go:noescape +func sumVec(res *Element, a *Element, n uint64) + // Mul z = x * y (mod q) // // x and y must be less than q diff --git a/ecc/bn254/fr/element_ops_amd64.s b/ecc/bn254/fr/element_ops_amd64.s index d077b1124..e5c890c82 100644 --- a/ecc/bn254/fr/element_ops_amd64.s +++ b/ecc/bn254/fr/element_ops_amd64.s @@ -625,3 +625,45 @@ noAdx_5: MOVQ AX, 48(SP) CALL ·scalarMulVecGeneric(SB) RET + +// sumVec(res, a *Element, n uint64) res = sum(a[0...n]) +TEXT ·sumVec(SB), NOSPLIT, $0-24 + MOVQ a+8(FP), AX + MOVQ n+16(FP), DX + XORQ R8, R8 + XORQ R9, R9 + XORQ R10, R10 + XORQ R11, R11 + +loop_8: + TESTQ DX, DX + JEQ done_9 // n == 0, we are done + + // a[0] -> CX + // a[1] -> BX + // a[2] -> SI + // a[3] -> DI + MOVQ 0(AX), CX + MOVQ 8(AX), BX + MOVQ 16(AX), SI + MOVQ 24(AX), DI + ADDQ CX, R8 + ADCQ BX, R9 + ADCQ SI, R10 + ADCQ DI, R11 + + // reduce element(R8,R9,R10,R11) using temp registers (R12,R13,R14,R15) + REDUCE(R8,R9,R10,R11,R12,R13,R14,R15) + + // increment pointers to visit next element + ADDQ $32, AX + DECQ DX // decrement n + JMP loop_8 + +done_9: + MOVQ res+0(FP), AX + MOVQ R8, 0(AX) + MOVQ R9, 8(AX) + MOVQ R10, 16(AX) + MOVQ R11, 24(AX) + RET diff --git a/ecc/bn254/fr/element_ops_purego.go b/ecc/bn254/fr/element_ops_purego.go index cd5c53d8f..2d1ae1780 100644 --- a/ecc/bn254/fr/element_ops_purego.go +++ b/ecc/bn254/fr/element_ops_purego.go @@ -78,6 +78,12 @@ func (vector *Vector) ScalarMul(a Vector, b *Element) { scalarMulVecGeneric(*vector, a, b) } +// Sum computes the sum of all elements in the vector. +func (vector *Vector) Sum() (res Element) { + sumVecGeneric(&res, *vector) + return +} + // Mul z = x * y (mod q) // // x and y must be less than q diff --git a/ecc/bn254/fr/element_test.go b/ecc/bn254/fr/element_test.go index 3be23d96a..059ab534f 100644 --- a/ecc/bn254/fr/element_test.go +++ b/ecc/bn254/fr/element_test.go @@ -742,6 +742,14 @@ func TestElementVecOps(t *testing.T) { expected.Mul(&a[i], &b[0]) assert.True(c[i].Equal(&expected), "Vector scaling failed") } + + // Vector sum + var sum Element + computed := c.Sum() + for i := 0; i < N; i++ { + sum.Add(&sum, &c[i]) + } + assert.True(sum.Equal(&computed), "Vector sum failed") } func BenchmarkElementVecOps(b *testing.B) { @@ -776,6 +784,15 @@ func BenchmarkElementVecOps(b *testing.B) { c1.ScalarMul(a1, &b1[0]) } }) + + b.Run("Sum", func(b *testing.B) { + b.ResetTimer() + var sum Element + for i := 0; i < b.N; i++ { + sum = c1.Sum() + } + _ = sum + }) } func TestElementAdd(t *testing.T) { diff --git a/ecc/bn254/fr/vector.go b/ecc/bn254/fr/vector.go index f39828547..d6a66b036 100644 --- a/ecc/bn254/fr/vector.go +++ b/ecc/bn254/fr/vector.go @@ -226,6 +226,12 @@ func scalarMulVecGeneric(res, a Vector, b *Element) { } } +func sumVecGeneric(res *Element, a Vector) { + for i := 0; i < len(a); i++ { + res.Add(res, &a[i]) + } +} + // TODO @gbotrel make a public package out of that. // execute executes the work function in parallel. 
// this is copy paste from internal/parallel/parallel.go diff --git a/ecc/bw6-633/fp/element_test.go b/ecc/bw6-633/fp/element_test.go index 169cd6701..103721a07 100644 --- a/ecc/bw6-633/fp/element_test.go +++ b/ecc/bw6-633/fp/element_test.go @@ -754,6 +754,14 @@ func TestElementVecOps(t *testing.T) { expected.Mul(&a[i], &b[0]) assert.True(c[i].Equal(&expected), "Vector scaling failed") } + + // Vector sum + var sum Element + computed := c.Sum() + for i := 0; i < N; i++ { + sum.Add(&sum, &c[i]) + } + assert.True(sum.Equal(&computed), "Vector sum failed") } func BenchmarkElementVecOps(b *testing.B) { @@ -788,6 +796,15 @@ func BenchmarkElementVecOps(b *testing.B) { c1.ScalarMul(a1, &b1[0]) } }) + + b.Run("Sum", func(b *testing.B) { + b.ResetTimer() + var sum Element + for i := 0; i < b.N; i++ { + sum = c1.Sum() + } + _ = sum + }) } func TestElementAdd(t *testing.T) { diff --git a/ecc/bw6-633/fp/vector.go b/ecc/bw6-633/fp/vector.go index 1bd71a36e..22b8a0254 100644 --- a/ecc/bw6-633/fp/vector.go +++ b/ecc/bw6-633/fp/vector.go @@ -223,6 +223,12 @@ func (vector *Vector) ScalarMul(a Vector, b *Element) { scalarMulVecGeneric(*vector, a, b) } +// Sum computes the sum of all elements in the vector. +func (vector *Vector) Sum() (res Element) { + sumVecGeneric(&res, *vector) + return +} + func addVecGeneric(res, a, b Vector) { if len(a) != len(b) || len(a) != len(res) { panic("vector.Add: vectors don't have the same length") @@ -250,6 +256,12 @@ func scalarMulVecGeneric(res, a Vector, b *Element) { } } +func sumVecGeneric(res *Element, a Vector) { + for i := 0; i < len(a); i++ { + res.Add(res, &a[i]) + } +} + // TODO @gbotrel make a public package out of that. // execute executes the work function in parallel. // this is copy paste from internal/parallel/parallel.go diff --git a/ecc/bw6-633/fr/element_test.go b/ecc/bw6-633/fr/element_test.go index e232de8c8..b12fba199 100644 --- a/ecc/bw6-633/fr/element_test.go +++ b/ecc/bw6-633/fr/element_test.go @@ -744,6 +744,14 @@ func TestElementVecOps(t *testing.T) { expected.Mul(&a[i], &b[0]) assert.True(c[i].Equal(&expected), "Vector scaling failed") } + + // Vector sum + var sum Element + computed := c.Sum() + for i := 0; i < N; i++ { + sum.Add(&sum, &c[i]) + } + assert.True(sum.Equal(&computed), "Vector sum failed") } func BenchmarkElementVecOps(b *testing.B) { @@ -778,6 +786,15 @@ func BenchmarkElementVecOps(b *testing.B) { c1.ScalarMul(a1, &b1[0]) } }) + + b.Run("Sum", func(b *testing.B) { + b.ResetTimer() + var sum Element + for i := 0; i < b.N; i++ { + sum = c1.Sum() + } + _ = sum + }) } func TestElementAdd(t *testing.T) { diff --git a/ecc/bw6-633/fr/vector.go b/ecc/bw6-633/fr/vector.go index 1c9b6b975..3331ee4d8 100644 --- a/ecc/bw6-633/fr/vector.go +++ b/ecc/bw6-633/fr/vector.go @@ -218,6 +218,12 @@ func (vector *Vector) ScalarMul(a Vector, b *Element) { scalarMulVecGeneric(*vector, a, b) } +// Sum computes the sum of all elements in the vector. +func (vector *Vector) Sum() (res Element) { + sumVecGeneric(&res, *vector) + return +} + func addVecGeneric(res, a, b Vector) { if len(a) != len(b) || len(a) != len(res) { panic("vector.Add: vectors don't have the same length") @@ -245,6 +251,12 @@ func scalarMulVecGeneric(res, a Vector, b *Element) { } } +func sumVecGeneric(res *Element, a Vector) { + for i := 0; i < len(a); i++ { + res.Add(res, &a[i]) + } +} + // TODO @gbotrel make a public package out of that. // execute executes the work function in parallel. 
// this is copy paste from internal/parallel/parallel.go diff --git a/ecc/bw6-761/fp/element_test.go b/ecc/bw6-761/fp/element_test.go index fbba1f286..083a0cda2 100644 --- a/ecc/bw6-761/fp/element_test.go +++ b/ecc/bw6-761/fp/element_test.go @@ -758,6 +758,14 @@ func TestElementVecOps(t *testing.T) { expected.Mul(&a[i], &b[0]) assert.True(c[i].Equal(&expected), "Vector scaling failed") } + + // Vector sum + var sum Element + computed := c.Sum() + for i := 0; i < N; i++ { + sum.Add(&sum, &c[i]) + } + assert.True(sum.Equal(&computed), "Vector sum failed") } func BenchmarkElementVecOps(b *testing.B) { @@ -792,6 +800,15 @@ func BenchmarkElementVecOps(b *testing.B) { c1.ScalarMul(a1, &b1[0]) } }) + + b.Run("Sum", func(b *testing.B) { + b.ResetTimer() + var sum Element + for i := 0; i < b.N; i++ { + sum = c1.Sum() + } + _ = sum + }) } func TestElementAdd(t *testing.T) { diff --git a/ecc/bw6-761/fp/vector.go b/ecc/bw6-761/fp/vector.go index 87105028b..a363dd18b 100644 --- a/ecc/bw6-761/fp/vector.go +++ b/ecc/bw6-761/fp/vector.go @@ -225,6 +225,12 @@ func (vector *Vector) ScalarMul(a Vector, b *Element) { scalarMulVecGeneric(*vector, a, b) } +// Sum computes the sum of all elements in the vector. +func (vector *Vector) Sum() (res Element) { + sumVecGeneric(&res, *vector) + return +} + func addVecGeneric(res, a, b Vector) { if len(a) != len(b) || len(a) != len(res) { panic("vector.Add: vectors don't have the same length") @@ -252,6 +258,12 @@ func scalarMulVecGeneric(res, a Vector, b *Element) { } } +func sumVecGeneric(res *Element, a Vector) { + for i := 0; i < len(a); i++ { + res.Add(res, &a[i]) + } +} + // TODO @gbotrel make a public package out of that. // execute executes the work function in parallel. // this is copy paste from internal/parallel/parallel.go diff --git a/ecc/bw6-761/fr/element_test.go b/ecc/bw6-761/fr/element_test.go index 0596297e8..22c72ba04 100644 --- a/ecc/bw6-761/fr/element_test.go +++ b/ecc/bw6-761/fr/element_test.go @@ -746,6 +746,14 @@ func TestElementVecOps(t *testing.T) { expected.Mul(&a[i], &b[0]) assert.True(c[i].Equal(&expected), "Vector scaling failed") } + + // Vector sum + var sum Element + computed := c.Sum() + for i := 0; i < N; i++ { + sum.Add(&sum, &c[i]) + } + assert.True(sum.Equal(&computed), "Vector sum failed") } func BenchmarkElementVecOps(b *testing.B) { @@ -780,6 +788,15 @@ func BenchmarkElementVecOps(b *testing.B) { c1.ScalarMul(a1, &b1[0]) } }) + + b.Run("Sum", func(b *testing.B) { + b.ResetTimer() + var sum Element + for i := 0; i < b.N; i++ { + sum = c1.Sum() + } + _ = sum + }) } func TestElementAdd(t *testing.T) { diff --git a/ecc/bw6-761/fr/vector.go b/ecc/bw6-761/fr/vector.go index 8dd4774c5..5749b4521 100644 --- a/ecc/bw6-761/fr/vector.go +++ b/ecc/bw6-761/fr/vector.go @@ -219,6 +219,12 @@ func (vector *Vector) ScalarMul(a Vector, b *Element) { scalarMulVecGeneric(*vector, a, b) } +// Sum computes the sum of all elements in the vector. +func (vector *Vector) Sum() (res Element) { + sumVecGeneric(&res, *vector) + return +} + func addVecGeneric(res, a, b Vector) { if len(a) != len(b) || len(a) != len(res) { panic("vector.Add: vectors don't have the same length") @@ -246,6 +252,12 @@ func scalarMulVecGeneric(res, a Vector, b *Element) { } } +func sumVecGeneric(res *Element, a Vector) { + for i := 0; i < len(a); i++ { + res.Add(res, &a[i]) + } +} + // TODO @gbotrel make a public package out of that. // execute executes the work function in parallel. 
// this is copy paste from internal/parallel/parallel.go diff --git a/ecc/secp256k1/fp/element_ops_purego.go b/ecc/secp256k1/fp/element_ops_purego.go index a8624a511..5e8497f5b 100644 --- a/ecc/secp256k1/fp/element_ops_purego.go +++ b/ecc/secp256k1/fp/element_ops_purego.go @@ -75,6 +75,12 @@ func (vector *Vector) ScalarMul(a Vector, b *Element) { scalarMulVecGeneric(*vector, a, b) } +// Sum computes the sum of all elements in the vector. +func (vector *Vector) Sum() (res Element) { + sumVecGeneric(&res, *vector) + return +} + // Mul z = x * y (mod q) func (z *Element) Mul(x, y *Element) *Element { diff --git a/ecc/secp256k1/fp/element_test.go b/ecc/secp256k1/fp/element_test.go index 6f8165b18..0781f7dd3 100644 --- a/ecc/secp256k1/fp/element_test.go +++ b/ecc/secp256k1/fp/element_test.go @@ -740,6 +740,14 @@ func TestElementVecOps(t *testing.T) { expected.Mul(&a[i], &b[0]) assert.True(c[i].Equal(&expected), "Vector scaling failed") } + + // Vector sum + var sum Element + computed := c.Sum() + for i := 0; i < N; i++ { + sum.Add(&sum, &c[i]) + } + assert.True(sum.Equal(&computed), "Vector sum failed") } func BenchmarkElementVecOps(b *testing.B) { @@ -774,6 +782,15 @@ func BenchmarkElementVecOps(b *testing.B) { c1.ScalarMul(a1, &b1[0]) } }) + + b.Run("Sum", func(b *testing.B) { + b.ResetTimer() + var sum Element + for i := 0; i < b.N; i++ { + sum = c1.Sum() + } + _ = sum + }) } func TestElementAdd(t *testing.T) { diff --git a/ecc/secp256k1/fp/vector.go b/ecc/secp256k1/fp/vector.go index 850b3603d..1c29dcbf2 100644 --- a/ecc/secp256k1/fp/vector.go +++ b/ecc/secp256k1/fp/vector.go @@ -226,6 +226,12 @@ func scalarMulVecGeneric(res, a Vector, b *Element) { } } +func sumVecGeneric(res *Element, a Vector) { + for i := 0; i < len(a); i++ { + res.Add(res, &a[i]) + } +} + // TODO @gbotrel make a public package out of that. // execute executes the work function in parallel. // this is copy paste from internal/parallel/parallel.go diff --git a/ecc/secp256k1/fr/element_ops_purego.go b/ecc/secp256k1/fr/element_ops_purego.go index 1a46f6d79..a9a314406 100644 --- a/ecc/secp256k1/fr/element_ops_purego.go +++ b/ecc/secp256k1/fr/element_ops_purego.go @@ -75,6 +75,12 @@ func (vector *Vector) ScalarMul(a Vector, b *Element) { scalarMulVecGeneric(*vector, a, b) } +// Sum computes the sum of all elements in the vector. 
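+// This is the pure Go path: it defers to sumVecGeneric, which accumulates
+// with the modular Add.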
+func (vector *Vector) Sum() (res Element) { + sumVecGeneric(&res, *vector) + return +} + // Mul z = x * y (mod q) func (z *Element) Mul(x, y *Element) *Element { diff --git a/ecc/secp256k1/fr/element_test.go b/ecc/secp256k1/fr/element_test.go index f554db8e3..85b23645b 100644 --- a/ecc/secp256k1/fr/element_test.go +++ b/ecc/secp256k1/fr/element_test.go @@ -740,6 +740,14 @@ func TestElementVecOps(t *testing.T) { expected.Mul(&a[i], &b[0]) assert.True(c[i].Equal(&expected), "Vector scaling failed") } + + // Vector sum + var sum Element + computed := c.Sum() + for i := 0; i < N; i++ { + sum.Add(&sum, &c[i]) + } + assert.True(sum.Equal(&computed), "Vector sum failed") } func BenchmarkElementVecOps(b *testing.B) { @@ -774,6 +782,15 @@ func BenchmarkElementVecOps(b *testing.B) { c1.ScalarMul(a1, &b1[0]) } }) + + b.Run("Sum", func(b *testing.B) { + b.ResetTimer() + var sum Element + for i := 0; i < b.N; i++ { + sum = c1.Sum() + } + _ = sum + }) } func TestElementAdd(t *testing.T) { diff --git a/ecc/secp256k1/fr/vector.go b/ecc/secp256k1/fr/vector.go index f39828547..d6a66b036 100644 --- a/ecc/secp256k1/fr/vector.go +++ b/ecc/secp256k1/fr/vector.go @@ -226,6 +226,12 @@ func scalarMulVecGeneric(res, a Vector, b *Element) { } } +func sumVecGeneric(res *Element, a Vector) { + for i := 0; i < len(a); i++ { + res.Add(res, &a[i]) + } +} + // TODO @gbotrel make a public package out of that. // execute executes the work function in parallel. // this is copy paste from internal/parallel/parallel.go diff --git a/ecc/stark-curve/fp/element_ops_amd64.go b/ecc/stark-curve/fp/element_ops_amd64.go index 6f16baf68..b7e3b3542 100644 --- a/ecc/stark-curve/fp/element_ops_amd64.go +++ b/ecc/stark-curve/fp/element_ops_amd64.go @@ -81,6 +81,18 @@ func (vector *Vector) ScalarMul(a Vector, b *Element) { //go:noescape func scalarMulVec(res, a, b *Element, n uint64) +// Sum computes the sum of all elements in the vector. 
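+// The accumulation loop lives in sumVec, defined in this curve's
+// element_ops_amd64.s.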
+func (vector *Vector) Sum() (res Element) { + if len(*vector) == 0 { + return + } + sumVec(&res, &(*vector)[0], uint64(len(*vector))) + return +} + +//go:noescape +func sumVec(res *Element, a *Element, n uint64) + // Mul z = x * y (mod q) // // x and y must be less than q diff --git a/ecc/stark-curve/fp/element_ops_amd64.s b/ecc/stark-curve/fp/element_ops_amd64.s index 914653b70..6f651dde6 100644 --- a/ecc/stark-curve/fp/element_ops_amd64.s +++ b/ecc/stark-curve/fp/element_ops_amd64.s @@ -625,3 +625,45 @@ noAdx_5: MOVQ AX, 48(SP) CALL ·scalarMulVecGeneric(SB) RET + +// sumVec(res, a *Element, n uint64) res = sum(a[0...n]) +TEXT ·sumVec(SB), NOSPLIT, $0-24 + MOVQ a+8(FP), AX + MOVQ n+16(FP), DX + XORQ R8, R8 + XORQ R9, R9 + XORQ R10, R10 + XORQ R11, R11 + +loop_8: + TESTQ DX, DX + JEQ done_9 // n == 0, we are done + + // a[0] -> CX + // a[1] -> BX + // a[2] -> SI + // a[3] -> DI + MOVQ 0(AX), CX + MOVQ 8(AX), BX + MOVQ 16(AX), SI + MOVQ 24(AX), DI + ADDQ CX, R8 + ADCQ BX, R9 + ADCQ SI, R10 + ADCQ DI, R11 + + // reduce element(R8,R9,R10,R11) using temp registers (R12,R13,R14,R15) + REDUCE(R8,R9,R10,R11,R12,R13,R14,R15) + + // increment pointers to visit next element + ADDQ $32, AX + DECQ DX // decrement n + JMP loop_8 + +done_9: + MOVQ res+0(FP), AX + MOVQ R8, 0(AX) + MOVQ R9, 8(AX) + MOVQ R10, 16(AX) + MOVQ R11, 24(AX) + RET diff --git a/ecc/stark-curve/fp/element_ops_purego.go b/ecc/stark-curve/fp/element_ops_purego.go index 4906d13e0..c7a46aa0f 100644 --- a/ecc/stark-curve/fp/element_ops_purego.go +++ b/ecc/stark-curve/fp/element_ops_purego.go @@ -78,6 +78,12 @@ func (vector *Vector) ScalarMul(a Vector, b *Element) { scalarMulVecGeneric(*vector, a, b) } +// Sum computes the sum of all elements in the vector. +func (vector *Vector) Sum() (res Element) { + sumVecGeneric(&res, *vector) + return +} + // Mul z = x * y (mod q) // // x and y must be less than q diff --git a/ecc/stark-curve/fp/element_test.go b/ecc/stark-curve/fp/element_test.go index 87e38f7c1..8da5b0e60 100644 --- a/ecc/stark-curve/fp/element_test.go +++ b/ecc/stark-curve/fp/element_test.go @@ -742,6 +742,14 @@ func TestElementVecOps(t *testing.T) { expected.Mul(&a[i], &b[0]) assert.True(c[i].Equal(&expected), "Vector scaling failed") } + + // Vector sum + var sum Element + computed := c.Sum() + for i := 0; i < N; i++ { + sum.Add(&sum, &c[i]) + } + assert.True(sum.Equal(&computed), "Vector sum failed") } func BenchmarkElementVecOps(b *testing.B) { @@ -776,6 +784,15 @@ func BenchmarkElementVecOps(b *testing.B) { c1.ScalarMul(a1, &b1[0]) } }) + + b.Run("Sum", func(b *testing.B) { + b.ResetTimer() + var sum Element + for i := 0; i < b.N; i++ { + sum = c1.Sum() + } + _ = sum + }) } func TestElementAdd(t *testing.T) { diff --git a/ecc/stark-curve/fp/vector.go b/ecc/stark-curve/fp/vector.go index 850b3603d..1c29dcbf2 100644 --- a/ecc/stark-curve/fp/vector.go +++ b/ecc/stark-curve/fp/vector.go @@ -226,6 +226,12 @@ func scalarMulVecGeneric(res, a Vector, b *Element) { } } +func sumVecGeneric(res *Element, a Vector) { + for i := 0; i < len(a); i++ { + res.Add(res, &a[i]) + } +} + // TODO @gbotrel make a public package out of that. // execute executes the work function in parallel. 
// this is copy paste from internal/parallel/parallel.go diff --git a/ecc/stark-curve/fr/element_ops_amd64.go b/ecc/stark-curve/fr/element_ops_amd64.go index 21568255d..88c2d0c58 100644 --- a/ecc/stark-curve/fr/element_ops_amd64.go +++ b/ecc/stark-curve/fr/element_ops_amd64.go @@ -81,6 +81,18 @@ func (vector *Vector) ScalarMul(a Vector, b *Element) { //go:noescape func scalarMulVec(res, a, b *Element, n uint64) +// Sum computes the sum of all elements in the vector. +func (vector *Vector) Sum() (res Element) { + if len(*vector) == 0 { + return + } + sumVec(&res, &(*vector)[0], uint64(len(*vector))) + return +} + +//go:noescape +func sumVec(res *Element, a *Element, n uint64) + // Mul z = x * y (mod q) // // x and y must be less than q diff --git a/ecc/stark-curve/fr/element_ops_amd64.s b/ecc/stark-curve/fr/element_ops_amd64.s index 245dcb895..0fc4cfd36 100644 --- a/ecc/stark-curve/fr/element_ops_amd64.s +++ b/ecc/stark-curve/fr/element_ops_amd64.s @@ -625,3 +625,45 @@ noAdx_5: MOVQ AX, 48(SP) CALL ·scalarMulVecGeneric(SB) RET + +// sumVec(res, a *Element, n uint64) res = sum(a[0...n]) +TEXT ·sumVec(SB), NOSPLIT, $0-24 + MOVQ a+8(FP), AX + MOVQ n+16(FP), DX + XORQ R8, R8 + XORQ R9, R9 + XORQ R10, R10 + XORQ R11, R11 + +loop_8: + TESTQ DX, DX + JEQ done_9 // n == 0, we are done + + // a[0] -> CX + // a[1] -> BX + // a[2] -> SI + // a[3] -> DI + MOVQ 0(AX), CX + MOVQ 8(AX), BX + MOVQ 16(AX), SI + MOVQ 24(AX), DI + ADDQ CX, R8 + ADCQ BX, R9 + ADCQ SI, R10 + ADCQ DI, R11 + + // reduce element(R8,R9,R10,R11) using temp registers (R12,R13,R14,R15) + REDUCE(R8,R9,R10,R11,R12,R13,R14,R15) + + // increment pointers to visit next element + ADDQ $32, AX + DECQ DX // decrement n + JMP loop_8 + +done_9: + MOVQ res+0(FP), AX + MOVQ R8, 0(AX) + MOVQ R9, 8(AX) + MOVQ R10, 16(AX) + MOVQ R11, 24(AX) + RET diff --git a/ecc/stark-curve/fr/element_ops_purego.go b/ecc/stark-curve/fr/element_ops_purego.go index b04f5202f..a45314560 100644 --- a/ecc/stark-curve/fr/element_ops_purego.go +++ b/ecc/stark-curve/fr/element_ops_purego.go @@ -78,6 +78,12 @@ func (vector *Vector) ScalarMul(a Vector, b *Element) { scalarMulVecGeneric(*vector, a, b) } +// Sum computes the sum of all elements in the vector. 
+func (vector *Vector) Sum() (res Element) { + sumVecGeneric(&res, *vector) + return +} + // Mul z = x * y (mod q) // // x and y must be less than q diff --git a/ecc/stark-curve/fr/element_test.go b/ecc/stark-curve/fr/element_test.go index b81aff116..6d0d2ec9e 100644 --- a/ecc/stark-curve/fr/element_test.go +++ b/ecc/stark-curve/fr/element_test.go @@ -742,6 +742,14 @@ func TestElementVecOps(t *testing.T) { expected.Mul(&a[i], &b[0]) assert.True(c[i].Equal(&expected), "Vector scaling failed") } + + // Vector sum + var sum Element + computed := c.Sum() + for i := 0; i < N; i++ { + sum.Add(&sum, &c[i]) + } + assert.True(sum.Equal(&computed), "Vector sum failed") } func BenchmarkElementVecOps(b *testing.B) { @@ -776,6 +784,15 @@ func BenchmarkElementVecOps(b *testing.B) { c1.ScalarMul(a1, &b1[0]) } }) + + b.Run("Sum", func(b *testing.B) { + b.ResetTimer() + var sum Element + for i := 0; i < b.N; i++ { + sum = c1.Sum() + } + _ = sum + }) } func TestElementAdd(t *testing.T) { diff --git a/ecc/stark-curve/fr/vector.go b/ecc/stark-curve/fr/vector.go index f39828547..d6a66b036 100644 --- a/ecc/stark-curve/fr/vector.go +++ b/ecc/stark-curve/fr/vector.go @@ -226,6 +226,12 @@ func scalarMulVecGeneric(res, a Vector, b *Element) { } } +func sumVecGeneric(res *Element, a Vector) { + for i := 0; i < len(a); i++ { + res.Add(res, &a[i]) + } +} + // TODO @gbotrel make a public package out of that. // execute executes the work function in parallel. // this is copy paste from internal/parallel/parallel.go diff --git a/field/generator/asm/amd64/build.go b/field/generator/asm/amd64/build.go index b760ad3e3..dbb1fbcd5 100644 --- a/field/generator/asm/amd64/build.go +++ b/field/generator/asm/amd64/build.go @@ -163,6 +163,7 @@ func Generate(w io.Writer, F *config.FieldConfig) error { f.generateAddVec() f.generateSubVec() f.generateScalarMulVec() + f.generateSumVec() } return nil diff --git a/field/generator/asm/amd64/element_vec.go b/field/generator/asm/amd64/element_vec.go index 05c2cf3f1..32406f362 100644 --- a/field/generator/asm/amd64/element_vec.go +++ b/field/generator/asm/amd64/element_vec.go @@ -240,3 +240,65 @@ func (f *FFAmd64) generateScalarMulVec() { f.RET() } + +// sumVec res = sum(a[0...n]) +func (f *FFAmd64) generateSumVec() { + f.Comment("sumVec(res, a *Element, n uint64) res = sum(a[0...n])") + + const argSize = 3 * 8 + stackSize := f.StackSize(f.NbWords*3+2, 0, 0) + registers := f.FnHeader("sumVec", stackSize, argSize) + defer f.AssertCleanStack(stackSize, 0) + + // registers & labels we need + addrA := f.Pop(®isters) + len := f.Pop(®isters) + + a := f.PopN(®isters) + t := f.PopN(®isters) + scratch := f.PopN(®isters) + + loop := f.NewLabel("loop") + done := f.NewLabel("done") + + // load arguments + f.MOVQ("a+8(FP)", addrA) + f.MOVQ("n+16(FP)", len) + + f.XORQ(t[0], t[0]) + f.XORQ(t[1], t[1]) + f.XORQ(t[2], t[2]) + f.XORQ(t[3], t[3]) + + f.LABEL(loop) + + f.TESTQ(len, len) + f.JEQ(done, "n == 0, we are done") + + // t += a + f.LabelRegisters("a", a...) + f.Mov(addrA, a) + f.Add(a, t) + + // reduce t + f.ReduceElement(t, scratch) + + f.Comment("increment pointers to visit next element") + f.ADDQ("$32", addrA) + f.DECQ(len, "decrement n") + f.JMP(loop) + + f.LABEL(done) + + // save t into res + f.MOVQ("res+0(FP)", addrA) + f.Mov(t, addrA) + + f.RET() + + f.Push(®isters, a...) + f.Push(®isters, t...) + f.Push(®isters, scratch...) 
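+	// addrA and len go back to the register pool as well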
+ f.Push(®isters, addrA, len) + +} diff --git a/field/generator/internal/templates/element/ops_asm.go b/field/generator/internal/templates/element/ops_asm.go index ffa7231e1..d2aa9c887 100644 --- a/field/generator/internal/templates/element/ops_asm.go +++ b/field/generator/internal/templates/element/ops_asm.go @@ -65,6 +65,18 @@ func (vector *Vector) ScalarMul(a Vector, b *{{.ElementName}}) { //go:noescape func scalarMulVec(res, a, b *{{.ElementName}}, n uint64) + +// Sum computes the sum of all elements in the vector. +func (vector *Vector) Sum() (res {{.ElementName}}) { + if len(*vector) == 0 { + return + } + sumVec(&res, &(*vector)[0], uint64(len(*vector))) + return +} + +//go:noescape +func sumVec(res *{{.ElementName}}, a *{{.ElementName}}, n uint64) {{- end}} // Mul z = x * y (mod q) diff --git a/field/generator/internal/templates/element/ops_purego.go b/field/generator/internal/templates/element/ops_purego.go index a4fde0d05..d3340a8ff 100644 --- a/field/generator/internal/templates/element/ops_purego.go +++ b/field/generator/internal/templates/element/ops_purego.go @@ -68,6 +68,12 @@ func (vector *Vector) Sub(a, b Vector) { func (vector *Vector) ScalarMul(a Vector, b *{{.ElementName}}) { scalarMulVecGeneric(*vector, a, b) } + +// Sum computes the sum of all elements in the vector. +func (vector *Vector) Sum() (res {{.ElementName}}) { + sumVecGeneric(&res, *vector) + return +} {{- end}} // Mul z = x * y (mod q) diff --git a/field/generator/internal/templates/element/tests.go b/field/generator/internal/templates/element/tests.go index 416b3f30e..cf7582b9d 100644 --- a/field/generator/internal/templates/element/tests.go +++ b/field/generator/internal/templates/element/tests.go @@ -763,6 +763,14 @@ func Test{{toTitle .ElementName}}VecOps(t *testing.T) { expected.Mul(&a[i], &b[0]) assert.True(c[i].Equal(&expected), "Vector scaling failed") } + + // Vector sum + var sum {{.ElementName}} + computed := c.Sum() + for i := 0; i < N; i++ { + sum.Add(&sum, &c[i]) + } + assert.True(sum.Equal(&computed), "Vector sum failed") } func Benchmark{{toTitle .ElementName}}VecOps(b *testing.B) { @@ -798,6 +806,15 @@ func Benchmark{{toTitle .ElementName}}VecOps(b *testing.B) { c1.ScalarMul(a1, &b1[0]) } }) + + b.Run("Sum", func(b *testing.B) { + b.ResetTimer() + var sum {{.ElementName}} + for i := 0; i < b.N; i++ { + sum = c1.Sum() + } + _ = sum + }) } diff --git a/field/generator/internal/templates/element/vector.go b/field/generator/internal/templates/element/vector.go index 8f06b54c9..6db71d7cd 100644 --- a/field/generator/internal/templates/element/vector.go +++ b/field/generator/internal/templates/element/vector.go @@ -211,6 +211,12 @@ func (vector *Vector) Sub(a, b Vector) { func (vector *Vector) ScalarMul(a Vector, b *{{.ElementName}}) { scalarMulVecGeneric(*vector, a, b) } + +// Sum computes the sum of all elements in the vector. +func (vector *Vector) Sum() (res {{.ElementName}}) { + sumVecGeneric(&res, *vector) + return +} {{- end}} @@ -242,6 +248,12 @@ func scalarMulVecGeneric(res, a Vector, b *{{.ElementName}}) { } } +func sumVecGeneric(res *{{.ElementName}}, a Vector) { + for i := 0; i < len(a); i++ { + res.Add(res, &a[i]) + } +} + // TODO @gbotrel make a public package out of that. // execute executes the work function in parallel. 
// this is copy paste from internal/parallel/parallel.go diff --git a/field/goldilocks/element_test.go b/field/goldilocks/element_test.go index 339fb4ea6..c3a106719 100644 --- a/field/goldilocks/element_test.go +++ b/field/goldilocks/element_test.go @@ -687,6 +687,14 @@ func TestElementVecOps(t *testing.T) { expected.Mul(&a[i], &b[0]) assert.True(c[i].Equal(&expected), "Vector scaling failed") } + + // Vector sum + var sum Element + computed := c.Sum() + for i := 0; i < N; i++ { + sum.Add(&sum, &c[i]) + } + assert.True(sum.Equal(&computed), "Vector sum failed") } func BenchmarkElementVecOps(b *testing.B) { @@ -721,6 +729,15 @@ func BenchmarkElementVecOps(b *testing.B) { c1.ScalarMul(a1, &b1[0]) } }) + + b.Run("Sum", func(b *testing.B) { + b.ResetTimer() + var sum Element + for i := 0; i < b.N; i++ { + sum = c1.Sum() + } + _ = sum + }) } func TestElementAdd(t *testing.T) { diff --git a/field/goldilocks/vector.go b/field/goldilocks/vector.go index 3de71afb8..9383691dc 100644 --- a/field/goldilocks/vector.go +++ b/field/goldilocks/vector.go @@ -214,6 +214,12 @@ func (vector *Vector) ScalarMul(a Vector, b *Element) { scalarMulVecGeneric(*vector, a, b) } +// Sum computes the sum of all elements in the vector. +func (vector *Vector) Sum() (res Element) { + sumVecGeneric(&res, *vector) + return +} + func addVecGeneric(res, a, b Vector) { if len(a) != len(b) || len(a) != len(res) { panic("vector.Add: vectors don't have the same length") @@ -241,6 +247,12 @@ func scalarMulVecGeneric(res, a Vector, b *Element) { } } +func sumVecGeneric(res *Element, a Vector) { + for i := 0; i < len(a); i++ { + res.Add(res, &a[i]) + } +} + // TODO @gbotrel make a public package out of that. // execute executes the work function in parallel. // this is copy paste from internal/parallel/parallel.go diff --git a/go.mod b/go.mod index 1cc1f399b..d2ac46944 100644 --- a/go.mod +++ b/go.mod @@ -4,7 +4,7 @@ go 1.22 require ( github.com/bits-and-blooms/bitset v1.14.2 - github.com/consensys/bavard v0.1.15 + github.com/consensys/bavard v0.0.0 github.com/leanovate/gopter v0.2.11 github.com/mmcloughlin/addchain v0.4.0 github.com/spf13/cobra v1.8.1 @@ -14,6 +14,8 @@ require ( gopkg.in/yaml.v2 v2.4.0 ) +replace github.com/consensys/bavard => ../bavard + require ( github.com/davecgh/go-spew v1.1.1 // indirect github.com/inconshreveable/mousetrap v1.1.0 // indirect From 6d2d8a874304d83912c4c1c6b480cfbeaf3c2795 Mon Sep 17 00:00:00 2001 From: Gautam Botrel Date: Sat, 21 Sep 2024 21:45:43 +0000 Subject: [PATCH 02/14] checkpoint --- ecc/bls12-377/fp/element_test.go | 27 +- ecc/bls12-377/fr/element_mul_amd64.s | 3 + ecc/bls12-377/fr/element_ops_amd64.s | 186 +++++++++-- ecc/bls12-377/fr/element_test.go | 27 +- ecc/bls12-381/fr/element_ops_amd64.s | 66 ++-- ecc/bls24-315/fr/element_ops_amd64.s | 66 ++-- ecc/bls24-317/fr/element_ops_amd64.s | 66 ++-- ecc/bn254/fp/element_ops_amd64.s | 98 ++++-- ecc/bn254/fr/element_ops_amd64.s | 98 ++++-- ecc/stark-curve/fp/element_ops_amd64.s | 66 ++-- ecc/stark-curve/fr/element_ops_amd64.s | 66 ++-- field/generator/asm/amd64/asm_macros.go | 6 + field/generator/asm/amd64/build.go | 4 + field/generator/asm/amd64/element_vec.go | 308 ++++++++++++++++-- field/generator/config/field_config.go | 11 + .../internal/templates/element/tests.go | 27 +- field/goldilocks/element_test.go | 27 +- internal/generator/main.go | 3 + 18 files changed, 873 insertions(+), 282 deletions(-) diff --git a/ecc/bls12-377/fp/element_test.go b/ecc/bls12-377/fp/element_test.go index 72c8af2ae..9b0b3258e 100644 --- 
a/ecc/bls12-377/fp/element_test.go +++ b/ecc/bls12-377/fp/element_test.go @@ -714,7 +714,7 @@ func TestElementLexicographicallyLargest(t *testing.T) { func TestElementVecOps(t *testing.T) { assert := require.New(t) - const N = 7 + const N = 4 a := make(Vector, N) b := make(Vector, N) c := make(Vector, N) @@ -749,10 +749,35 @@ func TestElementVecOps(t *testing.T) { // Vector sum var sum Element + const maxUint64 = ^uint64(0) + for i := 0; i < N; i++ { + c[i][0] = maxUint64 + c[i][1] = 0 + c[i][2] = 0 + c[i][3] = 0 + } + for i := 2; i < N; i++ { + c[i][0] = 0 + c[i][1] = 0 + c[i][2] = 0 + c[i][3] = 0 + } computed := c.Sum() for i := 0; i < N; i++ { sum.Add(&sum, &c[i]) } + // print computed[0], computed[1] in 64bit binary string + fmt.Printf("computed[0]: %64b\n", computed[0]) + fmt.Printf("computed[1]: %64b\n", computed[1]) + fmt.Printf("computed[2]: %64b\n", computed[2]) + fmt.Printf("computed[3]: %64b\n", computed[3]) + + // print the sum[0], sum[1] in 64bit binary string + fmt.Printf("sum [0]: %64b\n", sum[0]) + fmt.Printf("sum [1]: %64b\n", sum[1]) + fmt.Printf("sum [2]: %64b\n", sum[2]) + fmt.Printf("sum [3]: %64b\n", sum[3]) + assert.True(sum.Equal(&computed), "Vector sum failed") } diff --git a/ecc/bls12-377/fr/element_mul_amd64.s b/ecc/bls12-377/fr/element_mul_amd64.s index ab1816245..a8df29c64 100644 --- a/ecc/bls12-377/fr/element_mul_amd64.s +++ b/ecc/bls12-377/fr/element_mul_amd64.s @@ -27,6 +27,9 @@ GLOBL q<>(SB), (RODATA+NOPTR), $32 // qInv0 q'[0] DATA qInv0<>(SB)/8, $0x0a117fffffffffff GLOBL qInv0<>(SB), (RODATA+NOPTR), $8 +// Mu +DATA mu<>(SB)/8, $0x0000000db65247b1 +GLOBL mu<>(SB), (RODATA+NOPTR), $8 #define REDUCE(ra0, ra1, ra2, ra3, rb0, rb1, rb2, rb3) \ MOVQ ra0, rb0; \ diff --git a/ecc/bls12-377/fr/element_ops_amd64.s b/ecc/bls12-377/fr/element_ops_amd64.s index 175325820..36501ea2c 100644 --- a/ecc/bls12-377/fr/element_ops_amd64.s +++ b/ecc/bls12-377/fr/element_ops_amd64.s @@ -27,6 +27,9 @@ GLOBL q<>(SB), (RODATA+NOPTR), $32 // qInv0 q'[0] DATA qInv0<>(SB)/8, $0x0a117fffffffffff GLOBL qInv0<>(SB), (RODATA+NOPTR), $8 +// Mu +DATA mu<>(SB)/8, $0x0000000db65247b1 +GLOBL mu<>(SB), (RODATA+NOPTR), $8 #define REDUCE(ra0, ra1, ra2, ra3, rb0, rb1, rb2, rb3) \ MOVQ ra0, rb0; \ @@ -628,42 +631,159 @@ noAdx_5: // sumVec(res, a *Element, n uint64) res = sum(a[0...n]) TEXT ·sumVec(SB), NOSPLIT, $0-24 - MOVQ a+8(FP), AX - MOVQ n+16(FP), DX - XORQ R8, R8 - XORQ R9, R9 - XORQ R10, R10 - XORQ R11, R11 + MOVQ $0x1555, CX + KMOVW CX, K1 + MOVQ $0xff80, CX + KMOVW CX, K2 + MOVQ $0x01ff, CX + KMOVW CX, K3 + MOVQ a+8(FP), R14 + MOVQ n+16(FP), R15 + VXORPS Z0, Z0, Z0 + VMOVDQA64 Z0, Z1 + VMOVDQA64 Z0, Z2 + VMOVDQA64 Z0, Z3 + TESTQ R15, R15 + JEQ done_9 // n == 0, we are done + MOVQ R15, CX + ANDQ $3, CX + SHRQ $2, R15 + CMPB CX, $1 + JEQ rr1_10 // we have 1 remaining element + CMPB CX, $2 + JEQ rr2_11 // we have 2 remaining elements + CMPB CX, $3 + JNE loop_8 // == 0; we have 0 remaining elements + + // we have 3 remaining elements + VPMOVZXDQ 2*32(R14), Z4 + VPADDQ Z4, Z0, Z0 + +rr2_11: + // we have 2 remaining elements + VPMOVZXDQ 1*32(R14), Z4 + VPADDQ Z4, Z1, Z1 + +rr1_10: + // we have 1 remaining element + VPMOVZXDQ 0*32(R14), Z4 + VPADDQ Z4, Z2, Z2 loop_8: - TESTQ DX, DX - JEQ done_9 // n == 0, we are done - - // a[0] -> CX - // a[1] -> BX - // a[2] -> SI - // a[3] -> DI - MOVQ 0(AX), CX - MOVQ 8(AX), BX - MOVQ 16(AX), SI - MOVQ 24(AX), DI - ADDQ CX, R8 - ADCQ BX, R9 - ADCQ SI, R10 - ADCQ DI, R11 - - // reduce element(R8,R9,R10,R11) using temp registers (R12,R13,R14,R15) - 
REDUCE(R8,R9,R10,R11,R12,R13,R14,R15) - - // increment pointers to visit next element - ADDQ $32, AX - DECQ DX // decrement n + TESTQ R15, R15 + JEQ accumulate_12 // n == 0, we are going to accumulate + VPMOVZXDQ 0*32(R14), Z4 + VPADDQ Z4, Z0, Z0 + VPMOVZXDQ 1*32(R14), Z4 + VPADDQ Z4, Z1, Z1 + VPMOVZXDQ 2*32(R14), Z4 + VPADDQ Z4, Z2, Z2 + VPMOVZXDQ 3*32(R14), Z4 + VPADDQ Z4, Z3, Z3 + + // increment pointers to visit next 4 elements + ADDQ $128, R14 + DECQ R15 // decrement n JMP loop_8 +accumulate_12: + VPADDQ Z1, Z0, Z0 + VPADDQ Z3, Z2, Z2 + VPADDQ Z2, Z0, Z0 + VMOVQ X0, BX + VALIGNQ $1, Z0, Z0, Z0 + VMOVQ X0, SI + VALIGNQ $1, Z0, Z0, Z0 + VMOVQ X0, DI + VALIGNQ $1, Z0, Z0, Z0 + VMOVQ X0, R8 + VALIGNQ $1, Z0, Z0, Z0 + VMOVQ X0, R9 + VALIGNQ $1, Z0, Z0, Z0 + VMOVQ X0, R10 + VALIGNQ $1, Z0, Z0, Z0 + VMOVQ X0, R11 + VALIGNQ $1, Z0, Z0, Z0 + VMOVQ X0, R12 + XORQ AX, AX + MOVQ SI, AX + ANDQ $0xffffffff, AX + SHLQ $32, AX + SHRQ $32, SI + ADOXQ AX, BX + MOVQ R8, AX + ANDQ $0xffffffff, AX + SHLQ $32, AX + SHRQ $32, R8 + ADOXQ AX, DI + ADCXQ SI, DI + MOVQ R10, AX + ANDQ $0xffffffff, AX + SHLQ $32, AX + SHRQ $32, R10 + ADOXQ AX, R9 + ADCXQ R8, R9 + MOVQ R12, AX + ANDQ $0xffffffff, AX + SHLQ $32, AX + SHRQ $32, R12 + ADOXQ AX, R11 + ADCXQ R10, R11 + MOVQ $0, AX + ADOXQ AX, R12 + ADCXQ AX, R12 + MOVQ res+0(FP), R14 + MOVQ BX, 0(R14) + MOVQ DI, 8(R14) + MOVQ R9, 16(R14) + MOVQ R11, 24(R14) + RET + MOVQ mu<>(SB), CX + MOVQ R8, AX + SHRQ $32, R13, AX + MULQ CX + MULXQ q<>+0(SB), AX, CX + SUBQ AX, BX + SBBQ CX, SI + MULXQ q<>+16(SB), AX, CX + SBBQ AX, DI + SBBQ CX, R8 + SBBQ $0, R13 + MULXQ q<>+8(SB), AX, CX + SUBQ AX, SI + SBBQ CX, DI + MULXQ q<>+24(SB), AX, CX + SBBQ AX, R8 + SBBQ CX, R13 + MOVQ $0x0a11800000000001, R9 + MOVQ $0x59aa76fed0000001, R10 + MOVQ $0x60b44d1e5c37b001, R11 + MOVQ $0x12ab655e9a2ca556, R12 + MOVQ res+0(FP), R14 + MOVQ BX, 0(R14) + MOVQ SI, 8(R14) + MOVQ DI, 16(R14) + MOVQ R8, 24(R14) + SUBQ R9, BX + SBBQ R10, SI + SBBQ R11, DI + SBBQ R12, R8 + SBBQ $0, R13 + JCS done_9 + MOVQ BX, 0(R14) + MOVQ SI, 8(R14) + MOVQ DI, 16(R14) + MOVQ R8, 24(R14) + SUBQ R9, BX + SBBQ R10, SI + SBBQ R11, DI + SBBQ R12, R8 + SBBQ $0, R13 + JCS done_9 + MOVQ BX, 0(R14) + MOVQ SI, 8(R14) + MOVQ DI, 16(R14) + MOVQ R8, 24(R14) + done_9: - MOVQ res+0(FP), AX - MOVQ R8, 0(AX) - MOVQ R9, 8(AX) - MOVQ R10, 16(AX) - MOVQ R11, 24(AX) RET diff --git a/ecc/bls12-377/fr/element_test.go b/ecc/bls12-377/fr/element_test.go index 66b762147..8cfdfd8dc 100644 --- a/ecc/bls12-377/fr/element_test.go +++ b/ecc/bls12-377/fr/element_test.go @@ -710,7 +710,7 @@ func TestElementLexicographicallyLargest(t *testing.T) { func TestElementVecOps(t *testing.T) { assert := require.New(t) - const N = 7 + const N = 4 a := make(Vector, N) b := make(Vector, N) c := make(Vector, N) @@ -745,10 +745,35 @@ func TestElementVecOps(t *testing.T) { // Vector sum var sum Element + const maxUint64 = ^uint64(0) + for i := 0; i < N; i++ { + c[i][0] = maxUint64 + c[i][1] = 0 + c[i][2] = 0 + c[i][3] = 0 + } + for i := 2; i < N; i++ { + c[i][0] = 0 + c[i][1] = 0 + c[i][2] = 0 + c[i][3] = 0 + } computed := c.Sum() for i := 0; i < N; i++ { sum.Add(&sum, &c[i]) } + // print computed[0], computed[1] in 64bit binary string + fmt.Printf("computed[0]: %64b\n", computed[0]) + fmt.Printf("computed[1]: %64b\n", computed[1]) + fmt.Printf("computed[2]: %64b\n", computed[2]) + fmt.Printf("computed[3]: %64b\n", computed[3]) + + // print the sum[0], sum[1] in 64bit binary string + fmt.Printf("sum [0]: %64b\n", sum[0]) + fmt.Printf("sum [1]: %64b\n", sum[1]) + 
fmt.Printf("sum [2]: %64b\n", sum[2]) + fmt.Printf("sum [3]: %64b\n", sum[3]) + assert.True(sum.Equal(&computed), "Vector sum failed") } diff --git a/ecc/bls12-381/fr/element_ops_amd64.s b/ecc/bls12-381/fr/element_ops_amd64.s index bb51bb446..b9a186e46 100644 --- a/ecc/bls12-381/fr/element_ops_amd64.s +++ b/ecc/bls12-381/fr/element_ops_amd64.s @@ -628,42 +628,46 @@ noAdx_5: // sumVec(res, a *Element, n uint64) res = sum(a[0...n]) TEXT ·sumVec(SB), NOSPLIT, $0-24 - MOVQ a+8(FP), AX - MOVQ n+16(FP), DX - XORQ R8, R8 - XORQ R9, R9 - XORQ R10, R10 - XORQ R11, R11 + MOVQ a+8(FP), AX + MOVQ n+16(FP), DX + VXORQ ZMM0, ZMM0, ZMM0 + VMOVDQA64 ZMM0, ZMM1 + VMOVDQA64 ZMM0, ZMM2 + VMOVDQA64 ZMM0, ZMM3 + TESTQ DX, DX + JEQ done_9 // n == 0, we are done + MOVQ DX, CX + ANDQ $3, CX + SHRQ $2, DX + CMPQ $1, CX + JEQ r1_10 // we have 1 remaining element + CMPQ $2, CX + JEQ r2_11 // we have 2 remaining elements + CMPQ $3, CX + JNE loop_8 // == 0; we have 0 remaining elements + + // we have 3 remaining elements + VPMOVZXDQ 2*32(AX), ZMM4 + VPADDQ ZMM4, ZMM0, ZMM0 + +r2_11: + // we have 2 remaining elements + VPMOVZXDQ 1*32(AX), ZMM4 + VPADDQ ZMM4, ZMM1, ZMM1 + +r1_10: + // we have 1 remaining element + VPMOVZXDQ 0*32(AX), ZMM4 + VPADDQ ZMM4, ZMM2, ZMM2 + TESTQ DX, DX + JEQ done_9 // n == 0, we are done loop_8: - TESTQ DX, DX - JEQ done_9 // n == 0, we are done - - // a[0] -> CX - // a[1] -> BX - // a[2] -> SI - // a[3] -> DI - MOVQ 0(AX), CX - MOVQ 8(AX), BX - MOVQ 16(AX), SI - MOVQ 24(AX), DI - ADDQ CX, R8 - ADCQ BX, R9 - ADCQ SI, R10 - ADCQ DI, R11 - - // reduce element(R8,R9,R10,R11) using temp registers (R12,R13,R14,R15) - REDUCE(R8,R9,R10,R11,R12,R13,R14,R15) - // increment pointers to visit next element - ADDQ $32, AX - DECQ DX // decrement n + ADDQ $128, AX + DECQ DX // decrement n JMP loop_8 done_9: MOVQ res+0(FP), AX - MOVQ R8, 0(AX) - MOVQ R9, 8(AX) - MOVQ R10, 16(AX) - MOVQ R11, 24(AX) RET diff --git a/ecc/bls24-315/fr/element_ops_amd64.s b/ecc/bls24-315/fr/element_ops_amd64.s index bee861d35..8f4991b50 100644 --- a/ecc/bls24-315/fr/element_ops_amd64.s +++ b/ecc/bls24-315/fr/element_ops_amd64.s @@ -628,42 +628,46 @@ noAdx_5: // sumVec(res, a *Element, n uint64) res = sum(a[0...n]) TEXT ·sumVec(SB), NOSPLIT, $0-24 - MOVQ a+8(FP), AX - MOVQ n+16(FP), DX - XORQ R8, R8 - XORQ R9, R9 - XORQ R10, R10 - XORQ R11, R11 + MOVQ a+8(FP), AX + MOVQ n+16(FP), DX + VXORQ ZMM0, ZMM0, ZMM0 + VMOVDQA64 ZMM0, ZMM1 + VMOVDQA64 ZMM0, ZMM2 + VMOVDQA64 ZMM0, ZMM3 + TESTQ DX, DX + JEQ done_9 // n == 0, we are done + MOVQ DX, CX + ANDQ $3, CX + SHRQ $2, DX + CMPQ $1, CX + JEQ r1_10 // we have 1 remaining element + CMPQ $2, CX + JEQ r2_11 // we have 2 remaining elements + CMPQ $3, CX + JNE loop_8 // == 0; we have 0 remaining elements + + // we have 3 remaining elements + VPMOVZXDQ 2*32(AX), ZMM4 + VPADDQ ZMM4, ZMM0, ZMM0 + +r2_11: + // we have 2 remaining elements + VPMOVZXDQ 1*32(AX), ZMM4 + VPADDQ ZMM4, ZMM1, ZMM1 + +r1_10: + // we have 1 remaining element + VPMOVZXDQ 0*32(AX), ZMM4 + VPADDQ ZMM4, ZMM2, ZMM2 + TESTQ DX, DX + JEQ done_9 // n == 0, we are done loop_8: - TESTQ DX, DX - JEQ done_9 // n == 0, we are done - - // a[0] -> CX - // a[1] -> BX - // a[2] -> SI - // a[3] -> DI - MOVQ 0(AX), CX - MOVQ 8(AX), BX - MOVQ 16(AX), SI - MOVQ 24(AX), DI - ADDQ CX, R8 - ADCQ BX, R9 - ADCQ SI, R10 - ADCQ DI, R11 - - // reduce element(R8,R9,R10,R11) using temp registers (R12,R13,R14,R15) - REDUCE(R8,R9,R10,R11,R12,R13,R14,R15) - // increment pointers to visit next element - ADDQ $32, AX - DECQ DX // decrement n + ADDQ $128, AX 
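The generated sumVec bodies in these hunks (still incomplete at this point in the series) all share the same loop shape: peel off n mod 4 elements into separate accumulators, then consume 4 elements per iteration (4 x 32 bytes, hence the ADDQ $128), merging the accumulators only at the very end. A plain-Go sketch of that blocking scheme, with uint64 standing in for a full 4-word element and a single accumulator taking the n mod 4 stragglers; this is a structural illustration only, not the generated code.

package main

import "fmt"

// sumBlocked mirrors the loop structure of the generated sumVec assembly:
// handle n%4 leading elements one by one, then process the rest in blocks
// of 4 with independent accumulators (like Z0..Z3).
func sumBlocked(a []uint64) uint64 {
	var acc0, acc1, acc2, acc3 uint64
	n := len(a)
	rem := n % 4
	// remainder first, as in the assembly (rr1/rr2/rr3 fall-through)
	for i := 0; i < rem; i++ {
		acc0 += a[i]
	}
	for i := rem; i < n; i += 4 {
		acc0 += a[i]
		acc1 += a[i+1]
		acc2 += a[i+2]
		acc3 += a[i+3]
	}
	return acc0 + acc1 + acc2 + acc3
}

func main() {
	a := []uint64{1, 2, 3, 4, 5, 6, 7}
	fmt.Println(sumBlocked(a)) // 28
}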
+ DECQ DX // decrement n JMP loop_8 done_9: MOVQ res+0(FP), AX - MOVQ R8, 0(AX) - MOVQ R9, 8(AX) - MOVQ R10, 16(AX) - MOVQ R11, 24(AX) RET diff --git a/ecc/bls24-317/fr/element_ops_amd64.s b/ecc/bls24-317/fr/element_ops_amd64.s index cb475094a..ba8afdd99 100644 --- a/ecc/bls24-317/fr/element_ops_amd64.s +++ b/ecc/bls24-317/fr/element_ops_amd64.s @@ -628,42 +628,46 @@ noAdx_5: // sumVec(res, a *Element, n uint64) res = sum(a[0...n]) TEXT ·sumVec(SB), NOSPLIT, $0-24 - MOVQ a+8(FP), AX - MOVQ n+16(FP), DX - XORQ R8, R8 - XORQ R9, R9 - XORQ R10, R10 - XORQ R11, R11 + MOVQ a+8(FP), AX + MOVQ n+16(FP), DX + VXORQ ZMM0, ZMM0, ZMM0 + VMOVDQA64 ZMM0, ZMM1 + VMOVDQA64 ZMM0, ZMM2 + VMOVDQA64 ZMM0, ZMM3 + TESTQ DX, DX + JEQ done_9 // n == 0, we are done + MOVQ DX, CX + ANDQ $3, CX + SHRQ $2, DX + CMPQ $1, CX + JEQ r1_10 // we have 1 remaining element + CMPQ $2, CX + JEQ r2_11 // we have 2 remaining elements + CMPQ $3, CX + JNE loop_8 // == 0; we have 0 remaining elements + + // we have 3 remaining elements + VPMOVZXDQ 2*32(AX), ZMM4 + VPADDQ ZMM4, ZMM0, ZMM0 + +r2_11: + // we have 2 remaining elements + VPMOVZXDQ 1*32(AX), ZMM4 + VPADDQ ZMM4, ZMM1, ZMM1 + +r1_10: + // we have 1 remaining element + VPMOVZXDQ 0*32(AX), ZMM4 + VPADDQ ZMM4, ZMM2, ZMM2 + TESTQ DX, DX + JEQ done_9 // n == 0, we are done loop_8: - TESTQ DX, DX - JEQ done_9 // n == 0, we are done - - // a[0] -> CX - // a[1] -> BX - // a[2] -> SI - // a[3] -> DI - MOVQ 0(AX), CX - MOVQ 8(AX), BX - MOVQ 16(AX), SI - MOVQ 24(AX), DI - ADDQ CX, R8 - ADCQ BX, R9 - ADCQ SI, R10 - ADCQ DI, R11 - - // reduce element(R8,R9,R10,R11) using temp registers (R12,R13,R14,R15) - REDUCE(R8,R9,R10,R11,R12,R13,R14,R15) - // increment pointers to visit next element - ADDQ $32, AX - DECQ DX // decrement n + ADDQ $128, AX + DECQ DX // decrement n JMP loop_8 done_9: MOVQ res+0(FP), AX - MOVQ R8, 0(AX) - MOVQ R9, 8(AX) - MOVQ R10, 16(AX) - MOVQ R11, 24(AX) RET diff --git a/ecc/bn254/fp/element_ops_amd64.s b/ecc/bn254/fp/element_ops_amd64.s index 86d45bce1..6cb600960 100644 --- a/ecc/bn254/fp/element_ops_amd64.s +++ b/ecc/bn254/fp/element_ops_amd64.s @@ -628,42 +628,76 @@ noAdx_5: // sumVec(res, a *Element, n uint64) res = sum(a[0...n]) TEXT ·sumVec(SB), NOSPLIT, $0-24 - MOVQ a+8(FP), AX - MOVQ n+16(FP), DX - XORQ R8, R8 - XORQ R9, R9 - XORQ R10, R10 - XORQ R11, R11 + MOVQ $0x1555, CX + KMOVW CX, K1 + MOVQ $0xff80, CX + KMOVW CX, K2 + MOVQ $0x01ff, CX + KMOVW CX, K3 + MOVQ a+8(FP), AX + MOVQ n+16(FP), DX + VXORPS Z0, Z0, Z0 + VMOVDQA64 Z0, Z1 + VMOVDQA64 Z0, Z2 + VMOVDQA64 Z0, Z3 + TESTQ DX, DX + JEQ done_9 // n == 0, we are done + MOVQ DX, CX + ANDQ $3, CX + SHRQ $2, DX + CMPB CX, $1 + JEQ r1_10 // we have 1 remaining element + CMPB CX, $2 + JEQ r2_11 // we have 2 remaining elements + CMPB CX, $3 + JNE loop_8 // == 0; we have 0 remaining elements + + // we have 3 remaining elements + VPMOVZXDQ 2*32(AX), Z4 + VPADDQ Z4, Z0, Z0 + +r2_11: + // we have 2 remaining elements + VPMOVZXDQ 1*32(AX), Z4 + VPADDQ Z4, Z1, Z1 + +r1_10: + // we have 1 remaining element + VPMOVZXDQ 0*32(AX), Z4 + VPADDQ Z4, Z2, Z2 loop_8: - TESTQ DX, DX - JEQ done_9 // n == 0, we are done - - // a[0] -> CX - // a[1] -> BX - // a[2] -> SI - // a[3] -> DI - MOVQ 0(AX), CX - MOVQ 8(AX), BX - MOVQ 16(AX), SI - MOVQ 24(AX), DI - ADDQ CX, R8 - ADCQ BX, R9 - ADCQ SI, R10 - ADCQ DI, R11 - - // reduce element(R8,R9,R10,R11) using temp registers (R12,R13,R14,R15) - REDUCE(R8,R9,R10,R11,R12,R13,R14,R15) - - // increment pointers to visit next element - ADDQ $32, AX - DECQ DX // decrement n + TESTQ DX, DX + 
JEQ accumulate_12 // n == 0, we are going to accumulate + VPMOVZXDQ 0*32(AX), Z4 + VPADDQ Z4, Z0, Z0 + VPMOVZXDQ 1*32(AX), Z4 + VPADDQ Z4, Z1, Z1 + VPMOVZXDQ 2*32(AX), Z4 + VPADDQ Z4, Z2, Z2 + VPMOVZXDQ 3*32(AX), Z4 + VPADDQ Z4, Z3, Z3 + + // increment pointers to visit next 4 elements + ADDQ $128, AX + DECQ DX // decrement n JMP loop_8 +accumulate_12: + VPADDQ Z1, Z0, Z0 + VPADDQ Z3, Z2, Z2 + VPADDQ Z2, Z0, Z0 + MOVQ $8, CX + VALIGND $1, Z3, Z0, K2, Z3 + +propagate_13: + VPSRLQ $32, Z0, Z1 + VALIGND $2, Z0, Z0, K1, Z0 + VPADDQ Z1, Z0, Z0 + VALIGND $1, Z3, Z0, K2, Z3 + DECQ CX + JNE propagate_13 + done_9: MOVQ res+0(FP), AX - MOVQ R8, 0(AX) - MOVQ R9, 8(AX) - MOVQ R10, 16(AX) - MOVQ R11, 24(AX) RET diff --git a/ecc/bn254/fr/element_ops_amd64.s b/ecc/bn254/fr/element_ops_amd64.s index e5c890c82..ede44f461 100644 --- a/ecc/bn254/fr/element_ops_amd64.s +++ b/ecc/bn254/fr/element_ops_amd64.s @@ -628,42 +628,76 @@ noAdx_5: // sumVec(res, a *Element, n uint64) res = sum(a[0...n]) TEXT ·sumVec(SB), NOSPLIT, $0-24 - MOVQ a+8(FP), AX - MOVQ n+16(FP), DX - XORQ R8, R8 - XORQ R9, R9 - XORQ R10, R10 - XORQ R11, R11 + MOVQ $0x1555, CX + KMOVW CX, K1 + MOVQ $0xff80, CX + KMOVW CX, K2 + MOVQ $0x01ff, CX + KMOVW CX, K3 + MOVQ a+8(FP), AX + MOVQ n+16(FP), DX + VXORPS Z0, Z0, Z0 + VMOVDQA64 Z0, Z1 + VMOVDQA64 Z0, Z2 + VMOVDQA64 Z0, Z3 + TESTQ DX, DX + JEQ done_9 // n == 0, we are done + MOVQ DX, CX + ANDQ $3, CX + SHRQ $2, DX + CMPB CX, $1 + JEQ r1_10 // we have 1 remaining element + CMPB CX, $2 + JEQ r2_11 // we have 2 remaining elements + CMPB CX, $3 + JNE loop_8 // == 0; we have 0 remaining elements + + // we have 3 remaining elements + VPMOVZXDQ 2*32(AX), Z4 + VPADDQ Z4, Z0, Z0 + +r2_11: + // we have 2 remaining elements + VPMOVZXDQ 1*32(AX), Z4 + VPADDQ Z4, Z1, Z1 + +r1_10: + // we have 1 remaining element + VPMOVZXDQ 0*32(AX), Z4 + VPADDQ Z4, Z2, Z2 loop_8: - TESTQ DX, DX - JEQ done_9 // n == 0, we are done - - // a[0] -> CX - // a[1] -> BX - // a[2] -> SI - // a[3] -> DI - MOVQ 0(AX), CX - MOVQ 8(AX), BX - MOVQ 16(AX), SI - MOVQ 24(AX), DI - ADDQ CX, R8 - ADCQ BX, R9 - ADCQ SI, R10 - ADCQ DI, R11 - - // reduce element(R8,R9,R10,R11) using temp registers (R12,R13,R14,R15) - REDUCE(R8,R9,R10,R11,R12,R13,R14,R15) - - // increment pointers to visit next element - ADDQ $32, AX - DECQ DX // decrement n + TESTQ DX, DX + JEQ accumulate_12 // n == 0, we are going to accumulate + VPMOVZXDQ 0*32(AX), Z4 + VPADDQ Z4, Z0, Z0 + VPMOVZXDQ 1*32(AX), Z4 + VPADDQ Z4, Z1, Z1 + VPMOVZXDQ 2*32(AX), Z4 + VPADDQ Z4, Z2, Z2 + VPMOVZXDQ 3*32(AX), Z4 + VPADDQ Z4, Z3, Z3 + + // increment pointers to visit next 4 elements + ADDQ $128, AX + DECQ DX // decrement n JMP loop_8 +accumulate_12: + VPADDQ Z1, Z0, Z0 + VPADDQ Z3, Z2, Z2 + VPADDQ Z2, Z0, Z0 + MOVQ $8, CX + VALIGND $1, Z3, Z0, K2, Z3 + +propagate_13: + VPSRLQ $32, Z0, Z1 + VALIGND $2, Z0, Z0, K1, Z0 + VPADDQ Z1, Z0, Z0 + VALIGND $1, Z3, Z0, K2, Z3 + DECQ CX + JNE propagate_13 + done_9: MOVQ res+0(FP), AX - MOVQ R8, 0(AX) - MOVQ R9, 8(AX) - MOVQ R10, 16(AX) - MOVQ R11, 24(AX) RET diff --git a/ecc/stark-curve/fp/element_ops_amd64.s b/ecc/stark-curve/fp/element_ops_amd64.s index 6f651dde6..bd1ce4e36 100644 --- a/ecc/stark-curve/fp/element_ops_amd64.s +++ b/ecc/stark-curve/fp/element_ops_amd64.s @@ -628,42 +628,46 @@ noAdx_5: // sumVec(res, a *Element, n uint64) res = sum(a[0...n]) TEXT ·sumVec(SB), NOSPLIT, $0-24 - MOVQ a+8(FP), AX - MOVQ n+16(FP), DX - XORQ R8, R8 - XORQ R9, R9 - XORQ R10, R10 - XORQ R11, R11 + MOVQ a+8(FP), AX + MOVQ n+16(FP), DX + VXORQ ZMM0, ZMM0, ZMM0 + 
VMOVDQA64 ZMM0, ZMM1 + VMOVDQA64 ZMM0, ZMM2 + VMOVDQA64 ZMM0, ZMM3 + TESTQ DX, DX + JEQ done_9 // n == 0, we are done + MOVQ DX, CX + ANDQ $3, CX + SHRQ $2, DX + CMPQ $1, CX + JEQ r1_10 // we have 1 remaining element + CMPQ $2, CX + JEQ r2_11 // we have 2 remaining elements + CMPQ $3, CX + JNE loop_8 // == 0; we have 0 remaining elements + + // we have 3 remaining elements + VPMOVZXDQ 2*32(AX), ZMM4 + VPADDQ ZMM4, ZMM0, ZMM0 + +r2_11: + // we have 2 remaining elements + VPMOVZXDQ 1*32(AX), ZMM4 + VPADDQ ZMM4, ZMM1, ZMM1 + +r1_10: + // we have 1 remaining element + VPMOVZXDQ 0*32(AX), ZMM4 + VPADDQ ZMM4, ZMM2, ZMM2 + TESTQ DX, DX + JEQ done_9 // n == 0, we are done loop_8: - TESTQ DX, DX - JEQ done_9 // n == 0, we are done - - // a[0] -> CX - // a[1] -> BX - // a[2] -> SI - // a[3] -> DI - MOVQ 0(AX), CX - MOVQ 8(AX), BX - MOVQ 16(AX), SI - MOVQ 24(AX), DI - ADDQ CX, R8 - ADCQ BX, R9 - ADCQ SI, R10 - ADCQ DI, R11 - - // reduce element(R8,R9,R10,R11) using temp registers (R12,R13,R14,R15) - REDUCE(R8,R9,R10,R11,R12,R13,R14,R15) - // increment pointers to visit next element - ADDQ $32, AX - DECQ DX // decrement n + ADDQ $128, AX + DECQ DX // decrement n JMP loop_8 done_9: MOVQ res+0(FP), AX - MOVQ R8, 0(AX) - MOVQ R9, 8(AX) - MOVQ R10, 16(AX) - MOVQ R11, 24(AX) RET diff --git a/ecc/stark-curve/fr/element_ops_amd64.s b/ecc/stark-curve/fr/element_ops_amd64.s index 0fc4cfd36..ee525581c 100644 --- a/ecc/stark-curve/fr/element_ops_amd64.s +++ b/ecc/stark-curve/fr/element_ops_amd64.s @@ -628,42 +628,46 @@ noAdx_5: // sumVec(res, a *Element, n uint64) res = sum(a[0...n]) TEXT ·sumVec(SB), NOSPLIT, $0-24 - MOVQ a+8(FP), AX - MOVQ n+16(FP), DX - XORQ R8, R8 - XORQ R9, R9 - XORQ R10, R10 - XORQ R11, R11 + MOVQ a+8(FP), AX + MOVQ n+16(FP), DX + VXORQ ZMM0, ZMM0, ZMM0 + VMOVDQA64 ZMM0, ZMM1 + VMOVDQA64 ZMM0, ZMM2 + VMOVDQA64 ZMM0, ZMM3 + TESTQ DX, DX + JEQ done_9 // n == 0, we are done + MOVQ DX, CX + ANDQ $3, CX + SHRQ $2, DX + CMPQ $1, CX + JEQ r1_10 // we have 1 remaining element + CMPQ $2, CX + JEQ r2_11 // we have 2 remaining elements + CMPQ $3, CX + JNE loop_8 // == 0; we have 0 remaining elements + + // we have 3 remaining elements + VPMOVZXDQ 2*32(AX), ZMM4 + VPADDQ ZMM4, ZMM0, ZMM0 + +r2_11: + // we have 2 remaining elements + VPMOVZXDQ 1*32(AX), ZMM4 + VPADDQ ZMM4, ZMM1, ZMM1 + +r1_10: + // we have 1 remaining element + VPMOVZXDQ 0*32(AX), ZMM4 + VPADDQ ZMM4, ZMM2, ZMM2 + TESTQ DX, DX + JEQ done_9 // n == 0, we are done loop_8: - TESTQ DX, DX - JEQ done_9 // n == 0, we are done - - // a[0] -> CX - // a[1] -> BX - // a[2] -> SI - // a[3] -> DI - MOVQ 0(AX), CX - MOVQ 8(AX), BX - MOVQ 16(AX), SI - MOVQ 24(AX), DI - ADDQ CX, R8 - ADCQ BX, R9 - ADCQ SI, R10 - ADCQ DI, R11 - - // reduce element(R8,R9,R10,R11) using temp registers (R12,R13,R14,R15) - REDUCE(R8,R9,R10,R11,R12,R13,R14,R15) - // increment pointers to visit next element - ADDQ $32, AX - DECQ DX // decrement n + ADDQ $128, AX + DECQ DX // decrement n JMP loop_8 done_9: MOVQ res+0(FP), AX - MOVQ R8, 0(AX) - MOVQ R9, 8(AX) - MOVQ R10, 16(AX) - MOVQ R11, 24(AX) RET diff --git a/field/generator/asm/amd64/asm_macros.go b/field/generator/asm/amd64/asm_macros.go index 45d324c94..676ff0fb7 100644 --- a/field/generator/asm/amd64/asm_macros.go +++ b/field/generator/asm/amd64/asm_macros.go @@ -74,6 +74,12 @@ GLOBL q<>(SB), (RODATA+NOPTR), ${{mul 8 $.NbWords}} DATA qInv0<>(SB)/8, {{$qinv0 := index .QInverse 0}}{{imm $qinv0}} GLOBL qInv0<>(SB), (RODATA+NOPTR), $8 +{{- if eq .NbWords 4}} +// Mu +DATA mu<>(SB)/8, {{imm .Mu}} +GLOBL mu<>(SB), 
(RODATA+NOPTR), $8 +{{- end}} + #define REDUCE( {{- range $i := .NbWordsIndexesFull}}ra{{$i}},{{- end}} {{- range $i := .NbWordsIndexesFull}}rb{{$i}}{{- if ne $.NbWordsLastIndex $i}},{{- end}}{{- end}}) \ MOVQ ra0, rb0; \ diff --git a/field/generator/asm/amd64/build.go b/field/generator/asm/amd64/build.go index dbb1fbcd5..9a9a65cf7 100644 --- a/field/generator/asm/amd64/build.go +++ b/field/generator/asm/amd64/build.go @@ -135,6 +135,10 @@ func (f *FFAmd64) qInv0() string { return "qInv0<>(SB)" } +func (f *FFAmd64) mu() string { + return "mu<>(SB)" +} + // Generate generates assembly code for the base field provided to goff // see internal/templates/ops* func Generate(w io.Writer, F *config.FieldConfig) error { diff --git a/field/generator/asm/amd64/element_vec.go b/field/generator/asm/amd64/element_vec.go index 32406f362..4f74166db 100644 --- a/field/generator/asm/amd64/element_vec.go +++ b/field/generator/asm/amd64/element_vec.go @@ -14,7 +14,9 @@ package amd64 -import "github.com/consensys/bavard/amd64" +import ( + "github.com/consensys/bavard/amd64" +) // addVec res = a + b // func addVec(res, a, b *{{.ElementName}}, n uint64) @@ -246,59 +248,313 @@ func (f *FFAmd64) generateSumVec() { f.Comment("sumVec(res, a *Element, n uint64) res = sum(a[0...n])") const argSize = 3 * 8 - stackSize := f.StackSize(f.NbWords*3+2, 0, 0) - registers := f.FnHeader("sumVec", stackSize, argSize) + stackSize := f.StackSize(f.NbWords*2+4, 0, 0) + registers := f.FnHeader("sumVec", stackSize, argSize, amd64.DX, amd64.AX) defer f.AssertCleanStack(stackSize, 0) // registers & labels we need addrA := f.Pop(®isters) len := f.Pop(®isters) + tmp0 := f.Pop(®isters) - a := f.PopN(®isters) t := f.PopN(®isters) - scratch := f.PopN(®isters) + s := f.PopN(®isters) + t4 := f.Pop(®isters) loop := f.NewLabel("loop") done := f.NewLabel("done") + rr1 := f.NewLabel("rr1") + rr2 := f.NewLabel("rr2") + accumulate := f.NewLabel("accumulate") + // propagate := f.NewLabel("propagate") + + // AVX512 registers + Z0 := amd64.Register("Z0") + Z1 := amd64.Register("Z1") + Z2 := amd64.Register("Z2") + Z3 := amd64.Register("Z3") + Z4 := amd64.Register("Z4") + X0 := amd64.Register("X0") + + K1 := amd64.Register("K1") + K2 := amd64.Register("K2") + K3 := amd64.Register("K3") + + f.MOVQ("$0x1555", tmp0) + f.KMOVW(tmp0, K1) + + f.MOVQ("$0xff80", tmp0) + f.KMOVW(tmp0, K2) + + f.MOVQ("$0x01ff", tmp0) + f.KMOVW(tmp0, K3) // load arguments f.MOVQ("a+8(FP)", addrA) f.MOVQ("n+16(FP)", len) - f.XORQ(t[0], t[0]) - f.XORQ(t[1], t[1]) - f.XORQ(t[2], t[2]) - f.XORQ(t[3], t[3]) - - f.LABEL(loop) + // initialize accumulators to zero (zmm0, zmm1, zmm2, zmm3) + f.VXORPS(Z0, Z0, Z0) + f.VMOVDQA64(Z0, Z1) + f.VMOVDQA64(Z0, Z2) + f.VMOVDQA64(Z0, Z3) f.TESTQ(len, len) f.JEQ(done, "n == 0, we are done") - // t += a - f.LabelRegisters("a", a...) 
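The key trick in the AVX512 path being introduced here: VPMOVZXDQ zero-extends the eight 32-bit halves of a 4x64-bit element into eight 64-bit lanes, so VPADDQ can accumulate without any per-iteration modular reduction, each lane keeping 32 bits of headroom for carries. A scalar Go sketch of that split-and-recombine idea follows; it is illustrative only, and the lane ordering and final reduction are handled differently in the generated assembly.

package main

import (
	"fmt"
	"math/bits"
)

// splitAccumulate adds one 4x64-bit element into eight 64-bit lane
// accumulators, each lane holding a running sum of one 32-bit half-word.
// This mirrors VPMOVZXDQ (zero-extend dwords to qwords) followed by VPADDQ.
func splitAccumulate(lanes *[8]uint64, e [4]uint64) {
	for i, w := range e {
		lanes[2*i] += w & 0xffffffff // low 32 bits of word i
		lanes[2*i+1] += w >> 32      // high 32 bits of word i
	}
}

// recombine folds the lane sums back into a little-endian multi-word integer:
// lane 2i weighs 2^(64i), lane 2i+1 weighs 2^(64i+32).
func recombine(lanes [8]uint64) [5]uint64 {
	var r [5]uint64
	for i := 0; i < 4; i++ {
		lo, hi := lanes[2*i], lanes[2*i+1]
		var c uint64
		r[i], c = bits.Add64(r[i], lo, 0)
		r[i+1] += c
		r[i], c = bits.Add64(r[i], hi<<32, 0)
		r[i+1] += c + hi>>32
	}
	return r
}

func main() {
	var lanes [8]uint64
	e := [4]uint64{^uint64(0), 1, 0, 0} // the value 2^65 - 1
	splitAccumulate(&lanes, e)
	splitAccumulate(&lanes, e)
	fmt.Println(recombine(lanes)) // [18446744073709551614 3 0 0 0], i.e. 2^66 - 2
}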
- f.Mov(addrA, a) - f.Add(a, t) + f.MOVQ(len, tmp0) + f.ANDQ("$3", tmp0) // t0 = n % 4 + f.SHRQ("$2", len) // len = n / 4 - // reduce t - f.ReduceElement(t, scratch) + // if len % 4 != 0, we need to handle the remaining elements + f.CMPB(tmp0, "$1") + f.JEQ(rr1, "we have 1 remaining element") - f.Comment("increment pointers to visit next element") - f.ADDQ("$32", addrA) + f.CMPB(tmp0, "$2") + f.JEQ(rr2, "we have 2 remaining elements") + + f.CMPB(tmp0, "$3") + f.JNE(loop, "== 0; we have 0 remaining elements") + + f.Comment("we have 3 remaining elements") + // vpmovzxdq 2*32(PX), %zmm4; vpaddq %zmm4, %zmm0, %zmm0 + f.VPMOVZXDQ("2*32("+addrA+")", Z4) + f.VPADDQ(Z4, Z0, Z0) + + f.LABEL(rr2) + f.Comment("we have 2 remaining elements") + // vpmovzxdq 1*32(PX), %zmm4; vpaddq %zmm4, %zmm1, %zmm1 + f.VPMOVZXDQ("1*32("+addrA+")", Z4) + f.VPADDQ(Z4, Z1, Z1) + + f.LABEL(rr1) + f.Comment("we have 1 remaining element") + // vpmovzxdq 0*32(PX), %zmm4; vpaddq %zmm4, %zmm2, %zmm2 + f.VPMOVZXDQ("0*32("+addrA+")", Z4) + f.VPADDQ(Z4, Z2, Z2) + + f.LABEL(loop) + f.TESTQ(len, len) + f.JEQ(accumulate, "n == 0, we are going to accumulate") + + f.VPMOVZXDQ("0*32("+addrA+")", Z4) + f.VPADDQ(Z4, Z0, Z0) + + f.VPMOVZXDQ("1*32("+addrA+")", Z4) + f.VPADDQ(Z4, Z1, Z1) + + f.VPMOVZXDQ("2*32("+addrA+")", Z4) + f.VPADDQ(Z4, Z2, Z2) + + f.VPMOVZXDQ("3*32("+addrA+")", Z4) + f.VPADDQ(Z4, Z3, Z3) + + f.Comment("increment pointers to visit next 4 elements") + f.ADDQ("$128", addrA) f.DECQ(len, "decrement n") f.JMP(loop) - f.LABEL(done) + f.LABEL(accumulate) + + f.VPADDQ(Z1, Z0, Z0) + f.VPADDQ(Z3, Z2, Z2) + f.VPADDQ(Z2, Z0, Z0) + + // Propagate carries + f.VMOVQ(X0, t[0]) + f.VALIGNQ("$1", Z0, Z0, Z0) + f.VMOVQ(X0, t[1]) + f.VALIGNQ("$1", Z0, Z0, Z0) + f.VMOVQ(X0, t[2]) + f.VALIGNQ("$1", Z0, Z0, Z0) + f.VMOVQ(X0, t[3]) + f.VALIGNQ("$1", Z0, Z0, Z0) + f.VMOVQ(X0, s[0]) + f.VALIGNQ("$1", Z0, Z0, Z0) + f.VMOVQ(X0, s[1]) + f.VALIGNQ("$1", Z0, Z0, Z0) + f.VMOVQ(X0, s[2]) + f.VALIGNQ("$1", Z0, Z0, Z0) + f.VMOVQ(X0, s[3]) + + w0l := t[0] + w0h := t[1] + w1l := t[2] + w1h := t[3] + w2l := s[0] + w2h := s[1] + w3l := s[2] + w3h := s[3] + + // r0 = w0l + lo(woh) + // r1 = carry + hi(woh) + w1l + lo(w1h) + // r2 = carry + hi(w1h) + w2l + lo(w2h) + // r3 = carry + hi(w2h) + w3l + lo(w3h) + r0 := w0l + r1 := w1l + r2 := w2l + r3 := w3l + + // we need 2 carry so we use ADOXQ and ADCXQ + f.XORQ(amd64.AX, amd64.AX) + + // get low bits of w0h + f.MOVQ(w0h, amd64.AX) + f.ANDQ("$0xffffffff", amd64.AX) + f.SHLQ("$32", amd64.AX) + f.SHRQ("$32", w0h) + + // start the carry chain + f.ADOXQ(amd64.AX, w0l) // w0l is good. + + // get low bits of w1h + f.MOVQ(w1h, amd64.AX) + f.ANDQ("$0xffffffff", amd64.AX) + f.SHLQ("$32", amd64.AX) + f.SHRQ("$32", w1h) + + f.ADOXQ(amd64.AX, w1l) + f.ADCXQ(w0h, w1l) + + // get low bits of w2h + f.MOVQ(w2h, amd64.AX) + f.ANDQ("$0xffffffff", amd64.AX) + f.SHLQ("$32", amd64.AX) + f.SHRQ("$32", w2h) + + f.ADOXQ(amd64.AX, w2l) + f.ADCXQ(w1h, w2l) + + // get low bits of w3h + f.MOVQ(w3h, amd64.AX) + f.ANDQ("$0xffffffff", amd64.AX) + f.SHLQ("$32", amd64.AX) + f.SHRQ("$32", w3h) + + f.ADOXQ(amd64.AX, w3l) + f.ADCXQ(w2h, w3l) + r4 := w3h + f.MOVQ("$0", amd64.AX) + f.ADOXQ(amd64.AX, r4) + f.ADCXQ(amd64.AX, r4) + + // // we use AX for low 32bits + // f.MOVQ(t[1], amd64.AX) + // f.ANDQ("$0xffffffff", amd64.AX) + // f.SHRQ("$32", t[1]) + + // // start the carry chain + // f.ADDQ(amd64.AX, t[0]) // t0 is good. 
+ // // now t1, we have to add t1 + low(t2) + + // // // Propagate carries + // // mov $8, %eax + // // valignd $1, %zmm3, %zmm0, %zmm3{%k2}{z} // Shift lowest dword of zmm0 into zmm3 + // f.MOVQ("$8", tmp0) + // f.VALIGND("$1", Z3, Z0, K2, Z3) + + // f.LABEL(propagate) + // f.VPSRLQ("$32", Z0, Z1) + // f.VALIGND("$2", Z0, Z0, K1, Z0) + // f.VPADDQ(Z1, Z0, Z0) + // f.VALIGND("$1", Z3, Z0, K2, Z3) + + // f.DECQ(tmp0) + // f.JNE(propagate) + + // // The top 9 dwords of zmm3 now contain the sum + // // we shift by 224 bits to get the result in the low 32bytes + + // // // Move intermediate result to integer registers + // // The top 9 dwords of zmm3 now contain the sum. Copy them to the low end of zmm0. + // // valignd $7, %zmm3, %zmm3, %zmm0{%k3}{z} + // // // Copy to integer registers + // // vmovq %xmm0, T0; valignq $1, %zmm0, %zmm0, %zmm0 + // // vmovq %xmm0, T1; valignq $1, %zmm0, %zmm0, %zmm0 + // // vmovq %xmm0, T2; valignq $1, %zmm0, %zmm0, %zmm0 + // // vmovq %xmm0, T3; valignq $1, %zmm0, %zmm0, %zmm0 + // // vmovq %xmm0, T4 + + // f.VALIGND("$7", Z3, Z3, K3, Z0) + + // f.VMOVQ(X0, t[0]) + // f.VALIGNQ("$1", Z0, Z0, Z0) + // f.VMOVQ(X0, t[1]) + // f.VALIGNQ("$1", Z0, Z0, Z0) + // f.VMOVQ(X0, t[2]) + // f.VALIGNQ("$1", Z0, Z0, Z0) + // f.VMOVQ(X0, t[3]) + // f.VALIGNQ("$1", Z0, Z0, Z0) + // f.VMOVQ(X0, t4) - // save t into res f.MOVQ("res+0(FP)", addrA) - f.Mov(t, addrA) + r := []amd64.Register{r0, r1, r2, r3} + f.Mov(r, addrA) f.RET() - f.Push(®isters, a...) - f.Push(®isters, t...) - f.Push(®isters, scratch...) - f.Push(®isters, addrA, len) + // Reduce using single-word Barrett + // q1 is low 32 bits of T4 and high 32 bits of T3 + // movq T3, %rax + // shrd $32, T4, %rax + // mulq MU // Multiply by mu. q2 in rdx:rax, q3 in rdx + f.MOVQ(f.mu(), tmp0) + f.MOVQ(t[3], amd64.AX) + f.SHRQw("$32", t4, amd64.AX) + f.MULQ(tmp0) + + // Subtract r2 from r1 + // mulx 0*8(PM), PL, PH; sub PL, T0; sbb PH, T1; + // mulx 2*8(PM), PL, PH; sbb PL, T2; sbb PH, T3; sbb $0, T4 + // mulx 1*8(PM), PL, PH; sub PL, T1; sbb PH, T2; + // mulx 3*8(PM), PL, PH; sbb PL, T3; sbb PH, T4 + f.MULXQ(f.qAt(0), amd64.AX, tmp0) + f.SUBQ(amd64.AX, t[0]) + f.SBBQ(tmp0, t[1]) + + f.MULXQ(f.qAt(2), amd64.AX, tmp0) + f.SBBQ(amd64.AX, t[2]) + f.SBBQ(tmp0, t[3]) + f.SBBQ("$0", t4) + + f.MULXQ(f.qAt(1), amd64.AX, tmp0) + f.SUBQ(amd64.AX, t[1]) + f.SBBQ(tmp0, t[2]) + + f.MULXQ(f.qAt(3), amd64.AX, tmp0) + f.SBBQ(amd64.AX, t[3]) + f.SBBQ(tmp0, t4) + + // Two conditional subtractions to guarantee canonicity of the result + // substract modulus from t + f.Mov(f.Q, s) + + f.MOVQ("res+0(FP)", addrA) + f.Mov(t, addrA) + + f.Sub(s, t) + f.SBBQ("$0", t4) + // if borrow, skip to end + f.JCS(done) + + f.Mov(t, addrA) + f.Sub(s, t) + f.SBBQ("$0", t4) + // if borrow, skip to end + f.JCS(done) + f.Mov(t, addrA) + + // save t into res + + f.LABEL(done) + + // save t into res + + // f.Mov(t, addrA) + + f.RET() + f.Push(®isters, addrA, len, tmp0, t4) + f.Push(®isters, t...) + f.Push(®isters, s...) 
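The mu<>(SB) constant multiplied here is the one field_config.go (later in this patch) computes as floor(2^288 / q), kept as its low 64-bit word. A quick standalone math/big check of that constant for the bls12-377 fr modulus targeted by this generated file; this is a sanity check, not part of the generator.

package main

import (
	"fmt"
	"math/big"
)

func main() {
	// bls12-377 fr modulus, matching the q<> words in the generated assembly above.
	q, _ := new(big.Int).SetString("12ab655e9a2ca55660b44d1e5c37b00159aa76fed00000010a11800000000001", 16)
	mu := new(big.Int).Lsh(big.NewInt(1), 288)
	mu.Div(mu, q) // mu = floor(2^288 / q), as in FieldConfig.Mu
	fmt.Printf("0x%016x\n", mu) // expected to print the mu<>(SB) constant, 0x0000000db65247b1
}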
} diff --git a/field/generator/config/field_config.go b/field/generator/config/field_config.go index 457a89d7d..5da7da0a3 100644 --- a/field/generator/config/field_config.go +++ b/field/generator/config/field_config.go @@ -51,6 +51,7 @@ type FieldConfig struct { Q []uint64 QInverse []uint64 QMinusOneHalvedP []uint64 // ((q-1) / 2 ) + 1 + Mu uint64 // mu = 2^288 / q for barrett reduction ASM bool RSquare []uint64 One, Thirteen []uint64 @@ -117,6 +118,16 @@ func NewFieldConfig(packageName, elementName, modulus string, useAddChain bool) _qInv.Mod(_qInv, _r) F.QInverse = toUint64Slice(_qInv, F.NbWords) + // setting Mu 2^288 / q + if F.NbWords == 4 { + // TODO @gbotrel clean for all modulus. + _mu := big.NewInt(1) + _mu.Lsh(_mu, 288) + _mu.Div(_mu, &bModulus) + muSlice := toUint64Slice(_mu, F.NbWords) + F.Mu = muSlice[0] + } + // Pornin20 inversion correction factors k := 32 // Optimized for 64 bit machines, still works for 32 diff --git a/field/generator/internal/templates/element/tests.go b/field/generator/internal/templates/element/tests.go index cf7582b9d..5dd15cdcd 100644 --- a/field/generator/internal/templates/element/tests.go +++ b/field/generator/internal/templates/element/tests.go @@ -731,7 +731,7 @@ func Test{{toTitle .ElementName}}LexicographicallyLargest(t *testing.T) { func Test{{toTitle .ElementName}}VecOps(t *testing.T) { assert := require.New(t) - const N = 7 + const N = 4 a := make(Vector, N) b := make(Vector, N) c := make(Vector, N) @@ -766,10 +766,35 @@ func Test{{toTitle .ElementName}}VecOps(t *testing.T) { // Vector sum var sum {{.ElementName}} + const maxUint64 = ^uint64(0) + for i := 0; i < N; i++ { + c[i][0] = maxUint64 + c[i][1] = 0 + c[i][2] = 0 + c[i][3] = 0 + } + for i := 2; i < N; i++ { + c[i][0] = 0 + c[i][1] = 0 + c[i][2] = 0 + c[i][3] = 0 + } computed := c.Sum() for i := 0; i < N; i++ { sum.Add(&sum, &c[i]) } + // print computed[0], computed[1] in 64bit binary string + fmt.Printf("computed[0]: %64b\n", computed[0]) + fmt.Printf("computed[1]: %64b\n", computed[1]) + fmt.Printf("computed[2]: %64b\n", computed[2]) + fmt.Printf("computed[3]: %64b\n", computed[3]) + + // print the sum[0], sum[1] in 64bit binary string + fmt.Printf("sum [0]: %64b\n", sum[0]) + fmt.Printf("sum [1]: %64b\n", sum[1]) + fmt.Printf("sum [2]: %64b\n", sum[2]) + fmt.Printf("sum [3]: %64b\n", sum[3]) + assert.True(sum.Equal(&computed), "Vector sum failed") } diff --git a/field/goldilocks/element_test.go b/field/goldilocks/element_test.go index c3a106719..95857eb99 100644 --- a/field/goldilocks/element_test.go +++ b/field/goldilocks/element_test.go @@ -655,7 +655,7 @@ func TestElementLexicographicallyLargest(t *testing.T) { func TestElementVecOps(t *testing.T) { assert := require.New(t) - const N = 7 + const N = 4 a := make(Vector, N) b := make(Vector, N) c := make(Vector, N) @@ -690,10 +690,35 @@ func TestElementVecOps(t *testing.T) { // Vector sum var sum Element + const maxUint64 = ^uint64(0) + for i := 0; i < N; i++ { + c[i][0] = maxUint64 + c[i][1] = 0 + c[i][2] = 0 + c[i][3] = 0 + } + for i := 2; i < N; i++ { + c[i][0] = 0 + c[i][1] = 0 + c[i][2] = 0 + c[i][3] = 0 + } computed := c.Sum() for i := 0; i < N; i++ { sum.Add(&sum, &c[i]) } + // print computed[0], computed[1] in 64bit binary string + fmt.Printf("computed[0]: %64b\n", computed[0]) + fmt.Printf("computed[1]: %64b\n", computed[1]) + fmt.Printf("computed[2]: %64b\n", computed[2]) + fmt.Printf("computed[3]: %64b\n", computed[3]) + + // print the sum[0], sum[1] in 64bit binary string + fmt.Printf("sum [0]: %64b\n", sum[0]) + 
fmt.Printf("sum [1]: %64b\n", sum[1]) + fmt.Printf("sum [2]: %64b\n", sum[2]) + fmt.Printf("sum [3]: %64b\n", sum[3]) + assert.True(sum.Equal(&computed), "Vector sum failed") } diff --git a/internal/generator/main.go b/internal/generator/main.go index 389f96c2e..ad3a46039 100644 --- a/internal/generator/main.go +++ b/internal/generator/main.go @@ -46,6 +46,9 @@ func main() { var wg sync.WaitGroup for _, conf := range config.Curves { + if !conf.Equal(config.BLS12_377) { + continue + } wg.Add(1) // for each curve, generate the needed files go func(conf config.Curve) { From 0bb00a0cdf0eb45be7e66818466f50fe4edc5e8b Mon Sep 17 00:00:00 2001 From: Gautam Botrel Date: Sat, 21 Sep 2024 22:46:51 +0000 Subject: [PATCH 03/14] checkpoint --- ecc/bls12-377/fp/element_test.go | 36 +-- ecc/bls12-377/fr/element_ops_amd64.s | 140 ++++----- ecc/bls12-377/fr/element_test.go | 36 +-- field/generator/asm/amd64/element_vec.go | 281 +++++++----------- .../internal/templates/element/tests.go | 36 +-- field/goldilocks/element_test.go | 36 +-- 6 files changed, 243 insertions(+), 322 deletions(-) diff --git a/ecc/bls12-377/fp/element_test.go b/ecc/bls12-377/fp/element_test.go index 9b0b3258e..d289c9c24 100644 --- a/ecc/bls12-377/fp/element_test.go +++ b/ecc/bls12-377/fp/element_test.go @@ -714,7 +714,7 @@ func TestElementLexicographicallyLargest(t *testing.T) { func TestElementVecOps(t *testing.T) { assert := require.New(t) - const N = 4 + const N = 4 * 16 a := make(Vector, N) b := make(Vector, N) c := make(Vector, N) @@ -749,28 +749,28 @@ func TestElementVecOps(t *testing.T) { // Vector sum var sum Element - const maxUint64 = ^uint64(0) - for i := 0; i < N; i++ { - c[i][0] = maxUint64 - c[i][1] = 0 - c[i][2] = 0 - c[i][3] = 0 - } - for i := 2; i < N; i++ { - c[i][0] = 0 - c[i][1] = 0 - c[i][2] = 0 - c[i][3] = 0 - } + // const maxUint64 = ^uint64(0) + // for i := 0; i < N; i++ { + // c[i][0] = maxUint64 + // c[i][1] = 0 + // // c[i][2] = 0 + // c[i][3] = 0 + // } + // // for i := 2; i < N; i++ { + // // c[i][0] = 0 + // // c[i][1] = 0 + // // c[i][2] = 0 + // // c[i][3] = 0 + // // } computed := c.Sum() for i := 0; i < N; i++ { sum.Add(&sum, &c[i]) } // print computed[0], computed[1] in 64bit binary string - fmt.Printf("computed[0]: %64b\n", computed[0]) - fmt.Printf("computed[1]: %64b\n", computed[1]) - fmt.Printf("computed[2]: %64b\n", computed[2]) - fmt.Printf("computed[3]: %64b\n", computed[3]) + fmt.Printf("computed[0]: %64b %v \n", computed[0], computed[0] == sum[0]) + fmt.Printf("computed[1]: %64b %v \n", computed[1], computed[1] == sum[1]) + fmt.Printf("computed[2]: %64b %v \n", computed[2], computed[2] == sum[2]) + fmt.Printf("computed[3]: %64b %v \n", computed[3], computed[3] == sum[3]) // print the sum[0], sum[1] in 64bit binary string fmt.Printf("sum [0]: %64b\n", sum[0]) diff --git a/ecc/bls12-377/fr/element_ops_amd64.s b/ecc/bls12-377/fr/element_ops_amd64.s index 36501ea2c..28af7d7a2 100644 --- a/ecc/bls12-377/fr/element_ops_amd64.s +++ b/ecc/bls12-377/fr/element_ops_amd64.s @@ -631,12 +631,7 @@ noAdx_5: // sumVec(res, a *Element, n uint64) res = sum(a[0...n]) TEXT ·sumVec(SB), NOSPLIT, $0-24 - MOVQ $0x1555, CX - KMOVW CX, K1 - MOVQ $0xff80, CX - KMOVW CX, K2 - MOVQ $0x01ff, CX - KMOVW CX, K3 + XORQ AX, AX MOVQ a+8(FP), R14 MOVQ n+16(FP), R15 VXORPS Z0, Z0, Z0 @@ -706,84 +701,75 @@ accumulate_12: VALIGNQ $1, Z0, Z0, Z0 VMOVQ X0, R12 XORQ AX, AX - MOVQ SI, AX - ANDQ $0xffffffff, AX - SHLQ $32, AX + MOVQ SI, R13 + ANDQ $0xffffffff, R13 + SHLQ $32, R13 SHRQ $32, SI - ADOXQ AX, BX - MOVQ R8, AX - ANDQ 
$0xffffffff, AX - SHLQ $32, AX + MOVQ R8, CX + ANDQ $0xffffffff, CX + SHLQ $32, CX SHRQ $32, R8 - ADOXQ AX, DI - ADCXQ SI, DI - MOVQ R10, AX - ANDQ $0xffffffff, AX - SHLQ $32, AX + MOVQ R10, R15 + ANDQ $0xffffffff, R15 + SHLQ $32, R15 SHRQ $32, R10 - ADOXQ AX, R9 - ADCXQ R8, R9 - MOVQ R12, AX - ANDQ $0xffffffff, AX - SHLQ $32, AX + MOVQ R12, R14 + ANDQ $0xffffffff, R14 + SHLQ $32, R14 SHRQ $32, R12 - ADOXQ AX, R11 + XORQ AX, AX + ADOXQ R13, BX + ADOXQ CX, DI + ADCXQ SI, DI + ADOXQ R15, R9 + ADCXQ R8, R9 + ADOXQ R14, R11 ADCXQ R10, R11 - MOVQ $0, AX ADOXQ AX, R12 ADCXQ AX, R12 - MOVQ res+0(FP), R14 - MOVQ BX, 0(R14) - MOVQ DI, 8(R14) - MOVQ R9, 16(R14) - MOVQ R11, 24(R14) - RET - MOVQ mu<>(SB), CX - MOVQ R8, AX - SHRQ $32, R13, AX - MULQ CX - MULXQ q<>+0(SB), AX, CX - SUBQ AX, BX - SBBQ CX, SI - MULXQ q<>+16(SB), AX, CX - SBBQ AX, DI - SBBQ CX, R8 - SBBQ $0, R13 - MULXQ q<>+8(SB), AX, CX - SUBQ AX, SI - SBBQ CX, DI - MULXQ q<>+24(SB), AX, CX - SBBQ AX, R8 - SBBQ CX, R13 - MOVQ $0x0a11800000000001, R9 - MOVQ $0x59aa76fed0000001, R10 - MOVQ $0x60b44d1e5c37b001, R11 - MOVQ $0x12ab655e9a2ca556, R12 - MOVQ res+0(FP), R14 - MOVQ BX, 0(R14) - MOVQ SI, 8(R14) - MOVQ DI, 16(R14) - MOVQ R8, 24(R14) - SUBQ R9, BX - SBBQ R10, SI - SBBQ R11, DI - SBBQ R12, R8 - SBBQ $0, R13 - JCS done_9 - MOVQ BX, 0(R14) - MOVQ SI, 8(R14) - MOVQ DI, 16(R14) - MOVQ R8, 24(R14) - SUBQ R9, BX - SBBQ R10, SI - SBBQ R11, DI - SBBQ R12, R8 - SBBQ $0, R13 - JCS done_9 - MOVQ BX, 0(R14) - MOVQ SI, 8(R14) - MOVQ DI, 16(R14) - MOVQ R8, 24(R14) + XORQ AX, AX + MOVQ mu<>(SB), SI + MOVQ R11, AX + SHRQ $32, R12, AX + MULQ SI + MULXQ q<>+0(SB), AX, SI + SUBQ AX, BX + SBBQ SI, DI + MULXQ q<>+16(SB), AX, SI + SBBQ AX, R9 + SBBQ SI, R11 + SBBQ $0, R12 + MULXQ q<>+8(SB), AX, SI + SUBQ AX, DI + SBBQ SI, R9 + MULXQ q<>+24(SB), AX, SI + SBBQ AX, R11 + SBBQ SI, R12 + MOVQ res+0(FP), SI + MOVQ BX, 0(SI) + MOVQ DI, 8(SI) + MOVQ R9, 16(SI) + MOVQ R11, 24(SI) + SUBQ q<>+0(SB), BX + SBBQ q<>+8(SB), DI + SBBQ q<>+16(SB), R9 + SBBQ q<>+24(SB), R11 + SBBQ $0, R12 + JCS done_9 + MOVQ BX, 0(SI) + MOVQ DI, 8(SI) + MOVQ R9, 16(SI) + MOVQ R11, 24(SI) + SUBQ q<>+0(SB), BX + SBBQ q<>+8(SB), DI + SBBQ q<>+16(SB), R9 + SBBQ q<>+24(SB), R11 + SBBQ $0, R12 + JCS done_9 + MOVQ BX, 0(SI) + MOVQ DI, 8(SI) + MOVQ R9, 16(SI) + MOVQ R11, 24(SI) done_9: RET diff --git a/ecc/bls12-377/fr/element_test.go b/ecc/bls12-377/fr/element_test.go index 8cfdfd8dc..95415f8f9 100644 --- a/ecc/bls12-377/fr/element_test.go +++ b/ecc/bls12-377/fr/element_test.go @@ -710,7 +710,7 @@ func TestElementLexicographicallyLargest(t *testing.T) { func TestElementVecOps(t *testing.T) { assert := require.New(t) - const N = 4 + const N = 4 * 16 a := make(Vector, N) b := make(Vector, N) c := make(Vector, N) @@ -745,28 +745,28 @@ func TestElementVecOps(t *testing.T) { // Vector sum var sum Element - const maxUint64 = ^uint64(0) - for i := 0; i < N; i++ { - c[i][0] = maxUint64 - c[i][1] = 0 - c[i][2] = 0 - c[i][3] = 0 - } - for i := 2; i < N; i++ { - c[i][0] = 0 - c[i][1] = 0 - c[i][2] = 0 - c[i][3] = 0 - } + // const maxUint64 = ^uint64(0) + // for i := 0; i < N; i++ { + // c[i][0] = maxUint64 + // c[i][1] = 0 + // // c[i][2] = 0 + // c[i][3] = 0 + // } + // // for i := 2; i < N; i++ { + // // c[i][0] = 0 + // // c[i][1] = 0 + // // c[i][2] = 0 + // // c[i][3] = 0 + // // } computed := c.Sum() for i := 0; i < N; i++ { sum.Add(&sum, &c[i]) } // print computed[0], computed[1] in 64bit binary string - fmt.Printf("computed[0]: %64b\n", computed[0]) - fmt.Printf("computed[1]: %64b\n", computed[1]) - 
fmt.Printf("computed[2]: %64b\n", computed[2]) - fmt.Printf("computed[3]: %64b\n", computed[3]) + fmt.Printf("computed[0]: %64b %v \n", computed[0], computed[0] == sum[0]) + fmt.Printf("computed[1]: %64b %v \n", computed[1], computed[1] == sum[1]) + fmt.Printf("computed[2]: %64b %v \n", computed[2], computed[2] == sum[2]) + fmt.Printf("computed[3]: %64b %v \n", computed[3], computed[3] == sum[3]) // print the sum[0], sum[1] in 64bit binary string fmt.Printf("sum [0]: %64b\n", sum[0]) diff --git a/field/generator/asm/amd64/element_vec.go b/field/generator/asm/amd64/element_vec.go index 4f74166db..f270af8ae 100644 --- a/field/generator/asm/amd64/element_vec.go +++ b/field/generator/asm/amd64/element_vec.go @@ -257,10 +257,6 @@ func (f *FFAmd64) generateSumVec() { len := f.Pop(®isters) tmp0 := f.Pop(®isters) - t := f.PopN(®isters) - s := f.PopN(®isters) - t4 := f.Pop(®isters) - loop := f.NewLabel("loop") done := f.NewLabel("done") rr1 := f.NewLabel("rr1") @@ -276,18 +272,7 @@ func (f *FFAmd64) generateSumVec() { Z4 := amd64.Register("Z4") X0 := amd64.Register("X0") - K1 := amd64.Register("K1") - K2 := amd64.Register("K2") - K3 := amd64.Register("K3") - - f.MOVQ("$0x1555", tmp0) - f.KMOVW(tmp0, K1) - - f.MOVQ("$0xff80", tmp0) - f.KMOVW(tmp0, K2) - - f.MOVQ("$0x01ff", tmp0) - f.KMOVW(tmp0, K3) + f.XORQ(amd64.AX, amd64.AX) // load arguments f.MOVQ("a+8(FP)", addrA) @@ -316,6 +301,9 @@ func (f *FFAmd64) generateSumVec() { f.CMPB(tmp0, "$3") f.JNE(loop, "== 0; we have 0 remaining elements") + f.Push(®isters, tmp0) // we don't need tmp0 + tmp0 = "" + f.Comment("we have 3 remaining elements") // vpmovzxdq 2*32(PX), %zmm4; vpaddq %zmm4, %zmm0, %zmm0 f.VPMOVZXDQ("2*32("+addrA+")", Z4) @@ -354,207 +342,154 @@ func (f *FFAmd64) generateSumVec() { f.DECQ(len, "decrement n") f.JMP(loop) + f.Push(®isters, len, addrA) // we don't need len + len = "" + addrA = "" + f.LABEL(accumulate) f.VPADDQ(Z1, Z0, Z0) f.VPADDQ(Z3, Z2, Z2) f.VPADDQ(Z2, Z0, Z0) + w0l := f.Pop(®isters) + w0h := f.Pop(®isters) + w1l := f.Pop(®isters) + w1h := f.Pop(®isters) + w2l := f.Pop(®isters) + w2h := f.Pop(®isters) + w3l := f.Pop(®isters) + w3h := f.Pop(®isters) + low0h := f.Pop(®isters) + low1h := f.Pop(®isters) + low2h := f.Pop(®isters) + low3h := f.Pop(®isters) + // Propagate carries - f.VMOVQ(X0, t[0]) + f.VMOVQ(X0, w0l) f.VALIGNQ("$1", Z0, Z0, Z0) - f.VMOVQ(X0, t[1]) + f.VMOVQ(X0, w0h) f.VALIGNQ("$1", Z0, Z0, Z0) - f.VMOVQ(X0, t[2]) + f.VMOVQ(X0, w1l) f.VALIGNQ("$1", Z0, Z0, Z0) - f.VMOVQ(X0, t[3]) + f.VMOVQ(X0, w1h) f.VALIGNQ("$1", Z0, Z0, Z0) - f.VMOVQ(X0, s[0]) + f.VMOVQ(X0, w2l) f.VALIGNQ("$1", Z0, Z0, Z0) - f.VMOVQ(X0, s[1]) + f.VMOVQ(X0, w2h) f.VALIGNQ("$1", Z0, Z0, Z0) - f.VMOVQ(X0, s[2]) + f.VMOVQ(X0, w3l) f.VALIGNQ("$1", Z0, Z0, Z0) - f.VMOVQ(X0, s[3]) - - w0l := t[0] - w0h := t[1] - w1l := t[2] - w1h := t[3] - w2l := s[0] - w2h := s[1] - w3l := s[2] - w3h := s[3] + f.VMOVQ(X0, w3h) // r0 = w0l + lo(woh) // r1 = carry + hi(woh) + w1l + lo(w1h) // r2 = carry + hi(w1h) + w2l + lo(w2h) // r3 = carry + hi(w2h) + w3l + lo(w3h) - r0 := w0l - r1 := w1l - r2 := w2l - r3 := w3l // we need 2 carry so we use ADOXQ and ADCXQ f.XORQ(amd64.AX, amd64.AX) + type hilo struct { + hi, lo amd64.Register + } + for _, v := range []hilo{{w0h, low0h}, {w1h, low1h}, {w2h, low2h}, {w3h, low3h}} { + f.MOVQ(v.hi, v.lo) + f.ANDQ("$0xffffffff", v.lo) + f.SHLQ("$32", v.lo) + f.SHRQ("$32", v.hi) + } - // get low bits of w0h - f.MOVQ(w0h, amd64.AX) - f.ANDQ("$0xffffffff", amd64.AX) - f.SHLQ("$32", amd64.AX) - f.SHRQ("$32", w0h) - + f.XORQ(amd64.AX, 
amd64.AX) // start the carry chain - f.ADOXQ(amd64.AX, w0l) // w0l is good. - - // get low bits of w1h - f.MOVQ(w1h, amd64.AX) - f.ANDQ("$0xffffffff", amd64.AX) - f.SHLQ("$32", amd64.AX) - f.SHRQ("$32", w1h) + f.ADOXQ(low0h, w0l) - f.ADOXQ(amd64.AX, w1l) + f.ADOXQ(low1h, w1l) f.ADCXQ(w0h, w1l) - // get low bits of w2h - f.MOVQ(w2h, amd64.AX) - f.ANDQ("$0xffffffff", amd64.AX) - f.SHLQ("$32", amd64.AX) - f.SHRQ("$32", w2h) - - f.ADOXQ(amd64.AX, w2l) + f.ADOXQ(low2h, w2l) f.ADCXQ(w1h, w2l) - // get low bits of w3h - f.MOVQ(w3h, amd64.AX) - f.ANDQ("$0xffffffff", amd64.AX) - f.SHLQ("$32", amd64.AX) - f.SHRQ("$32", w3h) - - f.ADOXQ(amd64.AX, w3l) + f.ADOXQ(low3h, w3l) f.ADCXQ(w2h, w3l) + + f.ADOXQ(amd64.AX, w3h) + f.ADCXQ(amd64.AX, w3h) + + r0 := w0l + r1 := w1l + r2 := w2l + r3 := w3l r4 := w3h - f.MOVQ("$0", amd64.AX) - f.ADOXQ(amd64.AX, r4) - f.ADCXQ(amd64.AX, r4) - - // // we use AX for low 32bits - // f.MOVQ(t[1], amd64.AX) - // f.ANDQ("$0xffffffff", amd64.AX) - // f.SHRQ("$32", t[1]) - - // // start the carry chain - // f.ADDQ(amd64.AX, t[0]) // t0 is good. - // // now t1, we have to add t1 + low(t2) - - // // // Propagate carries - // // mov $8, %eax - // // valignd $1, %zmm3, %zmm0, %zmm3{%k2}{z} // Shift lowest dword of zmm0 into zmm3 - // f.MOVQ("$8", tmp0) - // f.VALIGND("$1", Z3, Z0, K2, Z3) - - // f.LABEL(propagate) - // f.VPSRLQ("$32", Z0, Z1) - // f.VALIGND("$2", Z0, Z0, K1, Z0) - // f.VPADDQ(Z1, Z0, Z0) - // f.VALIGND("$1", Z3, Z0, K2, Z3) - - // f.DECQ(tmp0) - // f.JNE(propagate) - - // // The top 9 dwords of zmm3 now contain the sum - // // we shift by 224 bits to get the result in the low 32bytes - - // // // Move intermediate result to integer registers - // // The top 9 dwords of zmm3 now contain the sum. Copy them to the low end of zmm0. - // // valignd $7, %zmm3, %zmm3, %zmm0{%k3}{z} - // // // Copy to integer registers - // // vmovq %xmm0, T0; valignq $1, %zmm0, %zmm0, %zmm0 - // // vmovq %xmm0, T1; valignq $1, %zmm0, %zmm0, %zmm0 - // // vmovq %xmm0, T2; valignq $1, %zmm0, %zmm0, %zmm0 - // // vmovq %xmm0, T3; valignq $1, %zmm0, %zmm0, %zmm0 - // // vmovq %xmm0, T4 - - // f.VALIGND("$7", Z3, Z3, K3, Z0) - - // f.VMOVQ(X0, t[0]) - // f.VALIGNQ("$1", Z0, Z0, Z0) - // f.VMOVQ(X0, t[1]) - // f.VALIGNQ("$1", Z0, Z0, Z0) - // f.VMOVQ(X0, t[2]) - // f.VALIGNQ("$1", Z0, Z0, Z0) - // f.VMOVQ(X0, t[3]) - // f.VALIGNQ("$1", Z0, Z0, Z0) - // f.VMOVQ(X0, t4) - - f.MOVQ("res+0(FP)", addrA) r := []amd64.Register{r0, r1, r2, r3} - f.Mov(r, addrA) - - f.RET() + // we don't need w0h, w1h, w2h anymore + f.Push(®isters, w0h, w1h, w2h) + w0h = "" + w1h = "" + w2h = "" + // we don't need the low bits anymore + f.Push(®isters, low0h, low1h, low2h, low3h) + low0h = "" + low1h = "" + low2h = "" + low3h = "" // Reduce using single-word Barrett // q1 is low 32 bits of T4 and high 32 bits of T3 // movq T3, %rax // shrd $32, T4, %rax // mulq MU // Multiply by mu. 
q2 in rdx:rax, q3 in rdx - f.MOVQ(f.mu(), tmp0) - f.MOVQ(t[3], amd64.AX) - f.SHRQw("$32", t4, amd64.AX) - f.MULQ(tmp0) - - // Subtract r2 from r1 - // mulx 0*8(PM), PL, PH; sub PL, T0; sbb PH, T1; - // mulx 2*8(PM), PL, PH; sbb PL, T2; sbb PH, T3; sbb $0, T4 - // mulx 1*8(PM), PL, PH; sub PL, T1; sbb PH, T2; - // mulx 3*8(PM), PL, PH; sbb PL, T3; sbb PH, T4 - f.MULXQ(f.qAt(0), amd64.AX, tmp0) - f.SUBQ(amd64.AX, t[0]) - f.SBBQ(tmp0, t[1]) - - f.MULXQ(f.qAt(2), amd64.AX, tmp0) - f.SBBQ(amd64.AX, t[2]) - f.SBBQ(tmp0, t[3]) - f.SBBQ("$0", t4) - - f.MULXQ(f.qAt(1), amd64.AX, tmp0) - f.SUBQ(amd64.AX, t[1]) - f.SBBQ(tmp0, t[2]) - - f.MULXQ(f.qAt(3), amd64.AX, tmp0) - f.SBBQ(amd64.AX, t[3]) - f.SBBQ(tmp0, t4) - - // Two conditional subtractions to guarantee canonicity of the result - // substract modulus from t - f.Mov(f.Q, s) - - f.MOVQ("res+0(FP)", addrA) - f.Mov(t, addrA) - - f.Sub(s, t) - f.SBBQ("$0", t4) - // if borrow, skip to end - f.JCS(done) + mu := f.Pop(®isters) - f.Mov(t, addrA) - f.Sub(s, t) - f.SBBQ("$0", t4) - // if borrow, skip to end - f.JCS(done) + f.XORQ(amd64.AX, amd64.AX) + f.MOVQ(f.mu(), mu) + f.MOVQ(r3, amd64.AX) + f.SHRQw("$32", r4, amd64.AX) + f.MULQ(mu) - f.Mov(t, addrA) + f.MULXQ(f.qAt(0), amd64.AX, mu) + f.SUBQ(amd64.AX, r0) + f.SBBQ(mu, r1) - // save t into res + f.MULXQ(f.qAt(2), amd64.AX, mu) + f.SBBQ(amd64.AX, r2) + f.SBBQ(mu, r3) + f.SBBQ("$0", r4) - f.LABEL(done) + f.MULXQ(f.qAt(1), amd64.AX, mu) + f.SUBQ(amd64.AX, r1) + f.SBBQ(mu, r2) + + f.MULXQ(f.qAt(3), amd64.AX, mu) + f.SBBQ(amd64.AX, r3) + f.SBBQ(mu, r4) + + addrRes := mu + f.MOVQ("res+0(FP)", addrRes) + f.Mov(r, addrRes) - // save t into res + // sub modulus + f.SUBQ(f.qAt(0), r0) + f.SBBQ(f.qAt(1), r1) + f.SBBQ(f.qAt(2), r2) + f.SBBQ(f.qAt(3), r3) + f.SBBQ("$0", r4) - // f.Mov(t, addrA) + // if borrow, we skip to the end + f.JCS(done) + f.Mov(r, addrRes) + f.SUBQ(f.qAt(0), r0) + f.SBBQ(f.qAt(1), r1) + f.SBBQ(f.qAt(2), r2) + f.SBBQ(f.qAt(3), r3) + f.SBBQ("$0", r4) + + // if borrow, we skip to the end + f.JCS(done) + f.Mov(r, addrRes) + + f.LABEL(done) f.RET() - f.Push(®isters, addrA, len, tmp0, t4) - f.Push(®isters, t...) - f.Push(®isters, s...) 
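The two conditional subtractions at the end rely on the single-word Barrett estimate floor(floor(t/2^224) * mu / 2^64) undershooting the true quotient floor(t/q) by at most 2 whenever t < 2^288, which holds here provided n < 2^32 (the same bound the 64-bit lanes need to avoid overflowing), since n elements below 2^256 sum to less than 2^288. A small math/big experiment sketching that bound for the bls12-377 fr modulus; again a sanity check only, not part of the generator.

package main

import (
	"crypto/rand"
	"fmt"
	"math/big"
)

func main() {
	// bls12-377 fr modulus (the field this generated file targets).
	q, _ := new(big.Int).SetString("12ab655e9a2ca55660b44d1e5c37b00159aa76fed00000010a11800000000001", 16)
	mu := new(big.Int).Lsh(big.NewInt(1), 288)
	mu.Div(mu, q) // mu = floor(2^288 / q)

	bound := new(big.Int).Lsh(big.NewInt(1), 288) // the accumulated sum stays below 2^288
	maxErr := int64(0)
	for i := 0; i < 100000; i++ {
		t, _ := rand.Int(rand.Reader, bound)
		// quotient estimate used by the assembly: qhat = floor(floor(t/2^224) * mu / 2^64)
		qhat := new(big.Int).Rsh(t, 224)
		qhat.Mul(qhat, mu)
		qhat.Rsh(qhat, 64)
		// true quotient
		qtrue := new(big.Int).Div(t, q)
		if e := new(big.Int).Sub(qtrue, qhat).Int64(); e > maxErr {
			maxErr = e
		}
	}
	// stays <= 2, so t - qhat*q < 3q and two conditional subtractions suffice
	fmt.Println("max underestimate:", maxErr)
}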
+ f.Push(®isters, mu) + f.Push(®isters, w0l, w1l, w2l, w3l, w3h) } diff --git a/field/generator/internal/templates/element/tests.go b/field/generator/internal/templates/element/tests.go index 5dd15cdcd..80b041649 100644 --- a/field/generator/internal/templates/element/tests.go +++ b/field/generator/internal/templates/element/tests.go @@ -731,7 +731,7 @@ func Test{{toTitle .ElementName}}LexicographicallyLargest(t *testing.T) { func Test{{toTitle .ElementName}}VecOps(t *testing.T) { assert := require.New(t) - const N = 4 + const N = 4*16 a := make(Vector, N) b := make(Vector, N) c := make(Vector, N) @@ -766,28 +766,28 @@ func Test{{toTitle .ElementName}}VecOps(t *testing.T) { // Vector sum var sum {{.ElementName}} - const maxUint64 = ^uint64(0) - for i := 0; i < N; i++ { - c[i][0] = maxUint64 - c[i][1] = 0 - c[i][2] = 0 - c[i][3] = 0 - } - for i := 2; i < N; i++ { - c[i][0] = 0 - c[i][1] = 0 - c[i][2] = 0 - c[i][3] = 0 - } + // const maxUint64 = ^uint64(0) + // for i := 0; i < N; i++ { + // c[i][0] = maxUint64 + // c[i][1] = 0 + // // c[i][2] = 0 + // c[i][3] = 0 + // } + // // for i := 2; i < N; i++ { + // // c[i][0] = 0 + // // c[i][1] = 0 + // // c[i][2] = 0 + // // c[i][3] = 0 + // // } computed := c.Sum() for i := 0; i < N; i++ { sum.Add(&sum, &c[i]) } // print computed[0], computed[1] in 64bit binary string - fmt.Printf("computed[0]: %64b\n", computed[0]) - fmt.Printf("computed[1]: %64b\n", computed[1]) - fmt.Printf("computed[2]: %64b\n", computed[2]) - fmt.Printf("computed[3]: %64b\n", computed[3]) + fmt.Printf("computed[0]: %64b %v \n", computed[0], computed[0]==sum[0]) + fmt.Printf("computed[1]: %64b %v \n", computed[1], computed[1]==sum[1]) + fmt.Printf("computed[2]: %64b %v \n", computed[2], computed[2]==sum[2]) + fmt.Printf("computed[3]: %64b %v \n", computed[3], computed[3]==sum[3]) // print the sum[0], sum[1] in 64bit binary string fmt.Printf("sum [0]: %64b\n", sum[0]) diff --git a/field/goldilocks/element_test.go b/field/goldilocks/element_test.go index 95857eb99..29e13113f 100644 --- a/field/goldilocks/element_test.go +++ b/field/goldilocks/element_test.go @@ -655,7 +655,7 @@ func TestElementLexicographicallyLargest(t *testing.T) { func TestElementVecOps(t *testing.T) { assert := require.New(t) - const N = 4 + const N = 4 * 16 a := make(Vector, N) b := make(Vector, N) c := make(Vector, N) @@ -690,28 +690,28 @@ func TestElementVecOps(t *testing.T) { // Vector sum var sum Element - const maxUint64 = ^uint64(0) - for i := 0; i < N; i++ { - c[i][0] = maxUint64 - c[i][1] = 0 - c[i][2] = 0 - c[i][3] = 0 - } - for i := 2; i < N; i++ { - c[i][0] = 0 - c[i][1] = 0 - c[i][2] = 0 - c[i][3] = 0 - } + // const maxUint64 = ^uint64(0) + // for i := 0; i < N; i++ { + // c[i][0] = maxUint64 + // c[i][1] = 0 + // // c[i][2] = 0 + // c[i][3] = 0 + // } + // // for i := 2; i < N; i++ { + // // c[i][0] = 0 + // // c[i][1] = 0 + // // c[i][2] = 0 + // // c[i][3] = 0 + // // } computed := c.Sum() for i := 0; i < N; i++ { sum.Add(&sum, &c[i]) } // print computed[0], computed[1] in 64bit binary string - fmt.Printf("computed[0]: %64b\n", computed[0]) - fmt.Printf("computed[1]: %64b\n", computed[1]) - fmt.Printf("computed[2]: %64b\n", computed[2]) - fmt.Printf("computed[3]: %64b\n", computed[3]) + fmt.Printf("computed[0]: %64b %v \n", computed[0], computed[0] == sum[0]) + fmt.Printf("computed[1]: %64b %v \n", computed[1], computed[1] == sum[1]) + fmt.Printf("computed[2]: %64b %v \n", computed[2], computed[2] == sum[2]) + fmt.Printf("computed[3]: %64b %v \n", computed[3], computed[3] == sum[3]) // print 
the sum[0], sum[1] in 64bit binary string fmt.Printf("sum [0]: %64b\n", sum[0]) From ce4ade2bbeec9629293f92a605756b3f5a597ce5 Mon Sep 17 00:00:00 2001 From: Gautam Botrel Date: Sat, 21 Sep 2024 23:00:25 +0000 Subject: [PATCH 04/14] checkpoint --- ecc/bls12-377/fp/element_test.go | 28 ++----------------- ecc/bls12-377/fr/element_ops_amd64.s | 9 ++++-- ecc/bls12-377/fr/element_test.go | 28 ++----------------- field/generator/asm/amd64/element_vec.go | 18 ++++++++---- .../internal/templates/element/tests.go | 28 ++----------------- field/goldilocks/element_test.go | 28 ++----------------- 6 files changed, 26 insertions(+), 113 deletions(-) diff --git a/ecc/bls12-377/fp/element_test.go b/ecc/bls12-377/fp/element_test.go index d289c9c24..236abd355 100644 --- a/ecc/bls12-377/fp/element_test.go +++ b/ecc/bls12-377/fp/element_test.go @@ -714,7 +714,7 @@ func TestElementLexicographicallyLargest(t *testing.T) { func TestElementVecOps(t *testing.T) { assert := require.New(t) - const N = 4 * 16 + const N = 4*16 + 4 a := make(Vector, N) b := make(Vector, N) c := make(Vector, N) @@ -749,34 +749,10 @@ func TestElementVecOps(t *testing.T) { // Vector sum var sum Element - // const maxUint64 = ^uint64(0) - // for i := 0; i < N; i++ { - // c[i][0] = maxUint64 - // c[i][1] = 0 - // // c[i][2] = 0 - // c[i][3] = 0 - // } - // // for i := 2; i < N; i++ { - // // c[i][0] = 0 - // // c[i][1] = 0 - // // c[i][2] = 0 - // // c[i][3] = 0 - // // } computed := c.Sum() for i := 0; i < N; i++ { sum.Add(&sum, &c[i]) } - // print computed[0], computed[1] in 64bit binary string - fmt.Printf("computed[0]: %64b %v \n", computed[0], computed[0] == sum[0]) - fmt.Printf("computed[1]: %64b %v \n", computed[1], computed[1] == sum[1]) - fmt.Printf("computed[2]: %64b %v \n", computed[2], computed[2] == sum[2]) - fmt.Printf("computed[3]: %64b %v \n", computed[3], computed[3] == sum[3]) - - // print the sum[0], sum[1] in 64bit binary string - fmt.Printf("sum [0]: %64b\n", sum[0]) - fmt.Printf("sum [1]: %64b\n", sum[1]) - fmt.Printf("sum [2]: %64b\n", sum[2]) - fmt.Printf("sum [3]: %64b\n", sum[3]) assert.True(sum.Equal(&computed), "Vector sum failed") } @@ -784,7 +760,7 @@ func TestElementVecOps(t *testing.T) { func BenchmarkElementVecOps(b *testing.B) { // note; to benchmark against "no asm" version, use the following // build tag: -tags purego - const N = 1024 + const N = 1024 * 10 a1 := make(Vector, N) b1 := make(Vector, N) c1 := make(Vector, N) diff --git a/ecc/bls12-377/fr/element_ops_amd64.s b/ecc/bls12-377/fr/element_ops_amd64.s index 28af7d7a2..195555ab0 100644 --- a/ecc/bls12-377/fr/element_ops_amd64.s +++ b/ecc/bls12-377/fr/element_ops_amd64.s @@ -643,11 +643,11 @@ TEXT ·sumVec(SB), NOSPLIT, $0-24 MOVQ R15, CX ANDQ $3, CX SHRQ $2, R15 - CMPB CX, $1 + CMPQ CX, $1 JEQ rr1_10 // we have 1 remaining element - CMPB CX, $2 + CMPQ CX, $2 JEQ rr2_11 // we have 2 remaining elements - CMPB CX, $3 + CMPQ CX, $3 JNE loop_8 // == 0; we have 0 remaining elements // we have 3 remaining elements @@ -663,6 +663,9 @@ rr1_10: // we have 1 remaining element VPMOVZXDQ 0*32(R14), Z4 VPADDQ Z4, Z2, Z2 + MOVQ $32, DX + IMULQ CX, DX + ADDQ DX, R14 loop_8: TESTQ R15, R15 diff --git a/ecc/bls12-377/fr/element_test.go b/ecc/bls12-377/fr/element_test.go index 95415f8f9..52098e507 100644 --- a/ecc/bls12-377/fr/element_test.go +++ b/ecc/bls12-377/fr/element_test.go @@ -710,7 +710,7 @@ func TestElementLexicographicallyLargest(t *testing.T) { func TestElementVecOps(t *testing.T) { assert := require.New(t) - const N = 4 * 16 + const N = 4*16 + 4 a := 
make(Vector, N) b := make(Vector, N) c := make(Vector, N) @@ -745,34 +745,10 @@ func TestElementVecOps(t *testing.T) { // Vector sum var sum Element - // const maxUint64 = ^uint64(0) - // for i := 0; i < N; i++ { - // c[i][0] = maxUint64 - // c[i][1] = 0 - // // c[i][2] = 0 - // c[i][3] = 0 - // } - // // for i := 2; i < N; i++ { - // // c[i][0] = 0 - // // c[i][1] = 0 - // // c[i][2] = 0 - // // c[i][3] = 0 - // // } computed := c.Sum() for i := 0; i < N; i++ { sum.Add(&sum, &c[i]) } - // print computed[0], computed[1] in 64bit binary string - fmt.Printf("computed[0]: %64b %v \n", computed[0], computed[0] == sum[0]) - fmt.Printf("computed[1]: %64b %v \n", computed[1], computed[1] == sum[1]) - fmt.Printf("computed[2]: %64b %v \n", computed[2], computed[2] == sum[2]) - fmt.Printf("computed[3]: %64b %v \n", computed[3], computed[3] == sum[3]) - - // print the sum[0], sum[1] in 64bit binary string - fmt.Printf("sum [0]: %64b\n", sum[0]) - fmt.Printf("sum [1]: %64b\n", sum[1]) - fmt.Printf("sum [2]: %64b\n", sum[2]) - fmt.Printf("sum [3]: %64b\n", sum[3]) assert.True(sum.Equal(&computed), "Vector sum failed") } @@ -780,7 +756,7 @@ func TestElementVecOps(t *testing.T) { func BenchmarkElementVecOps(b *testing.B) { // note; to benchmark against "no asm" version, use the following // build tag: -tags purego - const N = 1024 + const N = 1024 * 10 a1 := make(Vector, N) b1 := make(Vector, N) c1 := make(Vector, N) diff --git a/field/generator/asm/amd64/element_vec.go b/field/generator/asm/amd64/element_vec.go index f270af8ae..c51c1d789 100644 --- a/field/generator/asm/amd64/element_vec.go +++ b/field/generator/asm/amd64/element_vec.go @@ -292,18 +292,15 @@ func (f *FFAmd64) generateSumVec() { f.SHRQ("$2", len) // len = n / 4 // if len % 4 != 0, we need to handle the remaining elements - f.CMPB(tmp0, "$1") + f.CMPQ(tmp0, "$1") f.JEQ(rr1, "we have 1 remaining element") - f.CMPB(tmp0, "$2") + f.CMPQ(tmp0, "$2") f.JEQ(rr2, "we have 2 remaining elements") - f.CMPB(tmp0, "$3") + f.CMPQ(tmp0, "$3") f.JNE(loop, "== 0; we have 0 remaining elements") - f.Push(®isters, tmp0) // we don't need tmp0 - tmp0 = "" - f.Comment("we have 3 remaining elements") // vpmovzxdq 2*32(PX), %zmm4; vpaddq %zmm4, %zmm0, %zmm0 f.VPMOVZXDQ("2*32("+addrA+")", Z4) @@ -321,6 +318,15 @@ func (f *FFAmd64) generateSumVec() { f.VPMOVZXDQ("0*32("+addrA+")", Z4) f.VPADDQ(Z4, Z2, Z2) + // mul $32 by tmp0 + // TODO use better instructions + f.MOVQ("$32", amd64.DX) + f.IMULQ(tmp0, amd64.DX) + f.ADDQ(amd64.DX, addrA) + + f.Push(®isters, tmp0) // we don't need tmp0 + tmp0 = "" + f.LABEL(loop) f.TESTQ(len, len) f.JEQ(accumulate, "n == 0, we are going to accumulate") diff --git a/field/generator/internal/templates/element/tests.go b/field/generator/internal/templates/element/tests.go index 80b041649..5d9d23a75 100644 --- a/field/generator/internal/templates/element/tests.go +++ b/field/generator/internal/templates/element/tests.go @@ -731,7 +731,7 @@ func Test{{toTitle .ElementName}}LexicographicallyLargest(t *testing.T) { func Test{{toTitle .ElementName}}VecOps(t *testing.T) { assert := require.New(t) - const N = 4*16 + const N = 4*16 + 4 a := make(Vector, N) b := make(Vector, N) c := make(Vector, N) @@ -766,34 +766,10 @@ func Test{{toTitle .ElementName}}VecOps(t *testing.T) { // Vector sum var sum {{.ElementName}} - // const maxUint64 = ^uint64(0) - // for i := 0; i < N; i++ { - // c[i][0] = maxUint64 - // c[i][1] = 0 - // // c[i][2] = 0 - // c[i][3] = 0 - // } - // // for i := 2; i < N; i++ { - // // c[i][0] = 0 - // // c[i][1] = 0 - // // 
c[i][2] = 0 - // // c[i][3] = 0 - // // } computed := c.Sum() for i := 0; i < N; i++ { sum.Add(&sum, &c[i]) } - // print computed[0], computed[1] in 64bit binary string - fmt.Printf("computed[0]: %64b %v \n", computed[0], computed[0]==sum[0]) - fmt.Printf("computed[1]: %64b %v \n", computed[1], computed[1]==sum[1]) - fmt.Printf("computed[2]: %64b %v \n", computed[2], computed[2]==sum[2]) - fmt.Printf("computed[3]: %64b %v \n", computed[3], computed[3]==sum[3]) - - // print the sum[0], sum[1] in 64bit binary string - fmt.Printf("sum [0]: %64b\n", sum[0]) - fmt.Printf("sum [1]: %64b\n", sum[1]) - fmt.Printf("sum [2]: %64b\n", sum[2]) - fmt.Printf("sum [3]: %64b\n", sum[3]) assert.True(sum.Equal(&computed), "Vector sum failed") } @@ -801,7 +777,7 @@ func Test{{toTitle .ElementName}}VecOps(t *testing.T) { func Benchmark{{toTitle .ElementName}}VecOps(b *testing.B) { // note; to benchmark against "no asm" version, use the following // build tag: -tags purego - const N = 1024 + const N = 1024 * 10 a1 := make(Vector, N) b1 := make(Vector, N) c1 := make(Vector, N) diff --git a/field/goldilocks/element_test.go b/field/goldilocks/element_test.go index 29e13113f..64b1dddd8 100644 --- a/field/goldilocks/element_test.go +++ b/field/goldilocks/element_test.go @@ -655,7 +655,7 @@ func TestElementLexicographicallyLargest(t *testing.T) { func TestElementVecOps(t *testing.T) { assert := require.New(t) - const N = 4 * 16 + const N = 4*16 + 4 a := make(Vector, N) b := make(Vector, N) c := make(Vector, N) @@ -690,34 +690,10 @@ func TestElementVecOps(t *testing.T) { // Vector sum var sum Element - // const maxUint64 = ^uint64(0) - // for i := 0; i < N; i++ { - // c[i][0] = maxUint64 - // c[i][1] = 0 - // // c[i][2] = 0 - // c[i][3] = 0 - // } - // // for i := 2; i < N; i++ { - // // c[i][0] = 0 - // // c[i][1] = 0 - // // c[i][2] = 0 - // // c[i][3] = 0 - // // } computed := c.Sum() for i := 0; i < N; i++ { sum.Add(&sum, &c[i]) } - // print computed[0], computed[1] in 64bit binary string - fmt.Printf("computed[0]: %64b %v \n", computed[0], computed[0] == sum[0]) - fmt.Printf("computed[1]: %64b %v \n", computed[1], computed[1] == sum[1]) - fmt.Printf("computed[2]: %64b %v \n", computed[2], computed[2] == sum[2]) - fmt.Printf("computed[3]: %64b %v \n", computed[3], computed[3] == sum[3]) - - // print the sum[0], sum[1] in 64bit binary string - fmt.Printf("sum [0]: %64b\n", sum[0]) - fmt.Printf("sum [1]: %64b\n", sum[1]) - fmt.Printf("sum [2]: %64b\n", sum[2]) - fmt.Printf("sum [3]: %64b\n", sum[3]) assert.True(sum.Equal(&computed), "Vector sum failed") } @@ -725,7 +701,7 @@ func TestElementVecOps(t *testing.T) { func BenchmarkElementVecOps(b *testing.B) { // note; to benchmark against "no asm" version, use the following // build tag: -tags purego - const N = 1024 + const N = 1024 * 10 a1 := make(Vector, N) b1 := make(Vector, N) c1 := make(Vector, N) From 847a6dff9eb6bb8f7f32398af5c7d051046ddf84 Mon Sep 17 00:00:00 2001 From: Gautam Botrel Date: Sun, 22 Sep 2024 02:29:12 +0000 Subject: [PATCH 05/14] feat: add vec.Sum AVX512 --- ecc/bls12-377/fp/element_test.go | 4 +- ecc/bls12-377/fr/asm.go | 6 +- ecc/bls12-377/fr/asm_noadx.go | 6 +- ecc/bls12-377/fr/element_ops_amd64.go | 7 +- ecc/bls12-377/fr/element_ops_amd64.s | 242 +++++++++++------- ecc/bls12-377/fr/element_test.go | 4 +- ecc/bls12-381/fp/element_test.go | 5 +- ecc/bls12-381/fr/asm.go | 6 +- ecc/bls12-381/fr/asm_noadx.go | 6 +- ecc/bls12-381/fr/element_mul_amd64.s | 3 + ecc/bls12-381/fr/element_ops_amd64.go | 7 +- ecc/bls12-381/fr/element_ops_amd64.s | 227 
+++++++++++++--- ecc/bls12-381/fr/element_test.go | 5 +- ecc/bls24-315/fp/element_test.go | 5 +- ecc/bls24-315/fr/asm.go | 6 +- ecc/bls24-315/fr/asm_noadx.go | 6 +- ecc/bls24-315/fr/element_mul_amd64.s | 3 + ecc/bls24-315/fr/element_ops_amd64.go | 7 +- ecc/bls24-315/fr/element_ops_amd64.s | 227 +++++++++++++--- ecc/bls24-315/fr/element_test.go | 5 +- ecc/bls24-317/fp/element_test.go | 5 +- ecc/bls24-317/fr/asm.go | 6 +- ecc/bls24-317/fr/asm_noadx.go | 6 +- ecc/bls24-317/fr/element_mul_amd64.s | 3 + ecc/bls24-317/fr/element_ops_amd64.go | 7 +- ecc/bls24-317/fr/element_ops_amd64.s | 227 +++++++++++++--- ecc/bls24-317/fr/element_test.go | 5 +- ecc/bn254/fp/asm.go | 6 +- ecc/bn254/fp/asm_noadx.go | 6 +- ecc/bn254/fp/element_mul_amd64.s | 3 + ecc/bn254/fp/element_ops_amd64.go | 7 +- ecc/bn254/fp/element_ops_amd64.s | 223 ++++++++++++---- ecc/bn254/fp/element_test.go | 5 +- ecc/bn254/fr/asm.go | 6 +- ecc/bn254/fr/asm_noadx.go | 6 +- ecc/bn254/fr/element_mul_amd64.s | 3 + ecc/bn254/fr/element_ops_amd64.go | 7 +- ecc/bn254/fr/element_ops_amd64.s | 223 ++++++++++++---- ecc/bn254/fr/element_test.go | 5 +- ecc/bn254/internal/fptower/e2_amd64.s | 3 + ecc/bw6-633/fp/element_test.go | 5 +- ecc/bw6-633/fr/element_test.go | 5 +- ecc/bw6-761/fp/element_test.go | 5 +- ecc/bw6-761/fr/element_test.go | 5 +- ecc/secp256k1/fp/element_test.go | 5 +- ecc/secp256k1/fr/element_test.go | 5 +- ecc/stark-curve/fp/asm.go | 6 +- ecc/stark-curve/fp/asm_noadx.go | 6 +- ecc/stark-curve/fp/element_mul_amd64.s | 3 + ecc/stark-curve/fp/element_ops_amd64.go | 7 +- ecc/stark-curve/fp/element_ops_amd64.s | 227 +++++++++++++--- ecc/stark-curve/fp/element_test.go | 5 +- ecc/stark-curve/fr/asm.go | 6 +- ecc/stark-curve/fr/asm_noadx.go | 6 +- ecc/stark-curve/fr/element_mul_amd64.s | 3 + ecc/stark-curve/fr/element_ops_amd64.go | 7 +- ecc/stark-curve/fr/element_ops_amd64.s | 227 +++++++++++++--- ecc/stark-curve/fr/element_test.go | 5 +- field/generator/asm/amd64/element_vec.go | 124 ++++++--- .../internal/templates/element/asm.go | 8 + .../internal/templates/element/ops_asm.go | 7 +- .../internal/templates/element/tests.go | 4 +- field/goldilocks/element_test.go | 4 +- go.mod | 4 +- go.sum | 4 +- internal/generator/main.go | 3 - 66 files changed, 1772 insertions(+), 473 deletions(-) diff --git a/ecc/bls12-377/fp/element_test.go b/ecc/bls12-377/fp/element_test.go index 236abd355..8376c41e2 100644 --- a/ecc/bls12-377/fp/element_test.go +++ b/ecc/bls12-377/fp/element_test.go @@ -714,7 +714,7 @@ func TestElementLexicographicallyLargest(t *testing.T) { func TestElementVecOps(t *testing.T) { assert := require.New(t) - const N = 4*16 + 4 + const N = 1024*16 + 4 a := make(Vector, N) b := make(Vector, N) c := make(Vector, N) @@ -760,7 +760,7 @@ func TestElementVecOps(t *testing.T) { func BenchmarkElementVecOps(b *testing.B) { // note; to benchmark against "no asm" version, use the following // build tag: -tags purego - const N = 1024 * 10 + const N = 1 << 20 a1 := make(Vector, N) b1 := make(Vector, N) c1 := make(Vector, N) diff --git a/ecc/bls12-377/fr/asm.go b/ecc/bls12-377/fr/asm.go index da061913b..c756b153a 100644 --- a/ecc/bls12-377/fr/asm.go +++ b/ecc/bls12-377/fr/asm.go @@ -22,6 +22,8 @@ package fr import "golang.org/x/sys/cpu" var ( - supportAdx = cpu.X86.HasADX && cpu.X86.HasBMI2 - _ = supportAdx + supportAdx = cpu.X86.HasADX && cpu.X86.HasBMI2 + _ = supportAdx + supportAvx512 = cpu.X86.HasAVX512 && cpu.X86.HasAVX512DQ + _ = supportAvx512 ) diff --git a/ecc/bls12-377/fr/asm_noadx.go b/ecc/bls12-377/fr/asm_noadx.go index 
7f52ffa19..a116cd66e 100644 --- a/ecc/bls12-377/fr/asm_noadx.go +++ b/ecc/bls12-377/fr/asm_noadx.go @@ -23,6 +23,8 @@ package fr // certain errors (like fatal error: missing stackmap) // this ensures we test all asm path. var ( - supportAdx = false - _ = supportAdx + supportAdx = false + _ = supportAdx + supportAvx512 = false + _ = supportAvx512 ) diff --git a/ecc/bls12-377/fr/element_ops_amd64.go b/ecc/bls12-377/fr/element_ops_amd64.go index 88c2d0c58..fe4606378 100644 --- a/ecc/bls12-377/fr/element_ops_amd64.go +++ b/ecc/bls12-377/fr/element_ops_amd64.go @@ -83,7 +83,12 @@ func scalarMulVec(res, a, b *Element, n uint64) // Sum computes the sum of all elements in the vector. func (vector *Vector) Sum() (res Element) { - if len(*vector) == 0 { + n := uint64(len(*vector)) + const minN = 16 * 7 // AVX512 slower than generic for small n + const maxN = (1 << 32) + 1 + if !supportAvx512 || n <= minN || n >= maxN { + // call sumVecGeneric + sumVecGeneric(&res, *vector) return } sumVec(&res, &(*vector)[0], uint64(len(*vector))) diff --git a/ecc/bls12-377/fr/element_ops_amd64.s b/ecc/bls12-377/fr/element_ops_amd64.s index 195555ab0..0fdfd2104 100644 --- a/ecc/bls12-377/fr/element_ops_amd64.s +++ b/ecc/bls12-377/fr/element_ops_amd64.s @@ -631,35 +631,61 @@ noAdx_5: // sumVec(res, a *Element, n uint64) res = sum(a[0...n]) TEXT ·sumVec(SB), NOSPLIT, $0-24 - XORQ AX, AX - MOVQ a+8(FP), R14 - MOVQ n+16(FP), R15 + + // Derived from https://github.com/a16z/vectorized-fields + // The idea is to use Z registers to accumulate the sum of elements, 4 by 4 + // first, we handle the case where n % 4 != 0 and add to the accumulators the 1, 2 or 3 remaining elements + // then, we loop over the elements 4 by 4 and accumulate the sum in the Z registers + // finally, we reduce the sum and store it in res + // + // when we move an element of a into a Z register, we use VPMOVZXDQ + // let's note w0...w3 the 4 64bits words of ai: w0 = ai[0], w1 = ai[1], w2 = ai[2], w3 = ai[3] + // VPMOVZXDQ(ai, Z0) will result in + // Z0= [hi(w3), lo(w3), hi(w2), lo(w2), hi(w1), lo(w1), hi(w0), lo(w0)] + // with hi(wi) the high 32 bits of wi and lo(wi) the low 32 bits of wi + // we can safely add 2^32+1 times Z registers constructed this way without overflow + // since each of this lo/hi bits are moved into a "64bits" slot + // N = 2^64-1 / 2^32-1 = 2^32+1 + // + // we then propagate the carry using ADOXQ and ADCXQ + // r0 = w0l + lo(woh) + // r1 = carry + hi(woh) + w1l + lo(w1h) + // r2 = carry + hi(w1h) + w2l + lo(w2h) + // r3 = carry + hi(w2h) + w3l + lo(w3h) + // r4 = carry + hi(w3h) + // we then reduce the sum using a single-word Barrett reduction + + MOVQ a+8(FP), R14 + MOVQ n+16(FP), R15 + + // initialize accumulators Z0, Z1, Z2, Z3 VXORPS Z0, Z0, Z0 VMOVDQA64 Z0, Z1 VMOVDQA64 Z0, Z2 VMOVDQA64 Z0, Z3 - TESTQ R15, R15 - JEQ done_9 // n == 0, we are done - MOVQ R15, CX - ANDQ $3, CX - SHRQ $2, R15 - CMPQ CX, $1 - JEQ rr1_10 // we have 1 remaining element - CMPQ CX, $2 - JEQ rr2_11 // we have 2 remaining elements - CMPQ CX, $3 - JNE loop_8 // == 0; we have 0 remaining elements + + // n % 4 -> CX + // n / 4 -> R15 + MOVQ R15, CX + ANDQ $3, CX + SHRQ $2, R15 + CMPQ CX, $1 + JEQ r1_10 // we have 1 remaining element + CMPQ CX, $2 + JEQ r2_11 // we have 2 remaining elements + CMPQ CX, $3 + JNE loop4by4_8 // we have 0 remaining elements // we have 3 remaining elements VPMOVZXDQ 2*32(R14), Z4 VPADDQ Z4, Z0, Z0 -rr2_11: +r2_11: // we have 2 remaining elements VPMOVZXDQ 1*32(R14), Z4 VPADDQ Z4, Z1, Z1 -rr1_10: +r1_10: // we have 1 remaining 
element VPMOVZXDQ 0*32(R14), Z4 VPADDQ Z4, Z2, Z2 @@ -667,7 +693,7 @@ rr1_10: IMULQ CX, DX ADDQ DX, R14 -loop_8: +loop4by4_8: TESTQ R15, R15 JEQ accumulate_12 // n == 0, we are going to accumulate VPMOVZXDQ 0*32(R14), Z4 @@ -681,13 +707,24 @@ loop_8: // increment pointers to visit next 4 elements ADDQ $128, R14 - DECQ R15 // decrement n - JMP loop_8 + DECQ R15 // decrement n + JMP loop4by4_8 accumulate_12: - VPADDQ Z1, Z0, Z0 - VPADDQ Z3, Z2, Z2 - VPADDQ Z2, Z0, Z0 + // accumulate the 4 Z registers into Z0 + VPADDQ Z1, Z0, Z0 + VPADDQ Z3, Z2, Z2 + VPADDQ Z2, Z0, Z0 + + // carry propagation + // lo(w0) -> BX + // hi(w0) -> SI + // lo(w1) -> DI + // hi(w1) -> R8 + // lo(w2) -> R9 + // hi(w2) -> R10 + // lo(w3) -> R11 + // hi(w3) -> R12 VMOVQ X0, BX VALIGNQ $1, Z0, Z0, Z0 VMOVQ X0, SI @@ -703,76 +740,97 @@ accumulate_12: VMOVQ X0, R11 VALIGNQ $1, Z0, Z0, Z0 VMOVQ X0, R12 - XORQ AX, AX - MOVQ SI, R13 - ANDQ $0xffffffff, R13 - SHLQ $32, R13 - SHRQ $32, SI - MOVQ R8, CX - ANDQ $0xffffffff, CX - SHLQ $32, CX - SHRQ $32, R8 - MOVQ R10, R15 - ANDQ $0xffffffff, R15 - SHLQ $32, R15 - SHRQ $32, R10 - MOVQ R12, R14 - ANDQ $0xffffffff, R14 - SHLQ $32, R14 - SHRQ $32, R12 - XORQ AX, AX - ADOXQ R13, BX - ADOXQ CX, DI - ADCXQ SI, DI - ADOXQ R15, R9 - ADCXQ R8, R9 - ADOXQ R14, R11 - ADCXQ R10, R11 - ADOXQ AX, R12 - ADCXQ AX, R12 - XORQ AX, AX - MOVQ mu<>(SB), SI - MOVQ R11, AX - SHRQ $32, R12, AX - MULQ SI - MULXQ q<>+0(SB), AX, SI - SUBQ AX, BX - SBBQ SI, DI - MULXQ q<>+16(SB), AX, SI - SBBQ AX, R9 - SBBQ SI, R11 - SBBQ $0, R12 - MULXQ q<>+8(SB), AX, SI - SUBQ AX, DI - SBBQ SI, R9 - MULXQ q<>+24(SB), AX, SI - SBBQ AX, R11 - SBBQ SI, R12 - MOVQ res+0(FP), SI - MOVQ BX, 0(SI) - MOVQ DI, 8(SI) - MOVQ R9, 16(SI) - MOVQ R11, 24(SI) - SUBQ q<>+0(SB), BX - SBBQ q<>+8(SB), DI - SBBQ q<>+16(SB), R9 - SBBQ q<>+24(SB), R11 - SBBQ $0, R12 - JCS done_9 - MOVQ BX, 0(SI) - MOVQ DI, 8(SI) - MOVQ R9, 16(SI) - MOVQ R11, 24(SI) - SUBQ q<>+0(SB), BX - SBBQ q<>+8(SB), DI - SBBQ q<>+16(SB), R9 - SBBQ q<>+24(SB), R11 - SBBQ $0, R12 - JCS done_9 - MOVQ BX, 0(SI) - MOVQ DI, 8(SI) - MOVQ R9, 16(SI) - MOVQ R11, 24(SI) + + // lo(hi(wo)) -> R13 + // lo(hi(w1)) -> CX + // lo(hi(w2)) -> R15 + // lo(hi(w3)) -> R14 + XORQ AX, AX // clear the flags + MOVQ SI, R13 + ANDQ $0xffffffff, R13 + SHLQ $32, R13 + SHRQ $32, SI + MOVQ R8, CX + ANDQ $0xffffffff, CX + SHLQ $32, CX + SHRQ $32, R8 + MOVQ R10, R15 + ANDQ $0xffffffff, R15 + SHLQ $32, R15 + SHRQ $32, R10 + MOVQ R12, R14 + ANDQ $0xffffffff, R14 + SHLQ $32, R14 + SHRQ $32, R12 + + // r0 = w0l + lo(woh) + // r1 = carry + hi(woh) + w1l + lo(w1h) + // r2 = carry + hi(w1h) + w2l + lo(w2h) + // r3 = carry + hi(w2h) + w3l + lo(w3h) + // r4 = carry + hi(w3h) + + XORQ AX, AX // clear the flags + ADOXQ R13, BX + ADOXQ CX, DI + ADCXQ SI, DI + ADOXQ R15, R9 + ADCXQ R8, R9 + ADOXQ R14, R11 + ADCXQ R10, R11 + ADOXQ AX, R12 + ADCXQ AX, R12 + + // r[0] -> BX + // r[1] -> DI + // r[2] -> R9 + // r[3] -> R11 + // r[4] -> R12 + // reduce using single-word Barrett + // mu=2^288 / q -> SI + MOVQ mu<>(SB), SI + MOVQ R11, AX + SHRQ $32, R12, AX + MULQ SI + MULXQ q<>+0(SB), AX, SI + SUBQ AX, BX + SBBQ SI, DI + MULXQ q<>+16(SB), AX, SI + SBBQ AX, R9 + SBBQ SI, R11 + SBBQ $0, R12 + MULXQ q<>+8(SB), AX, SI + SUBQ AX, DI + SBBQ SI, R9 + MULXQ q<>+24(SB), AX, SI + SBBQ AX, R11 + SBBQ SI, R12 + MOVQ res+0(FP), SI + MOVQ BX, 0(SI) + MOVQ DI, 8(SI) + MOVQ R9, 16(SI) + MOVQ R11, 24(SI) + + // TODO @gbotrel check if 2 conditional substracts is guaranteed to be suffisant for mod reduce + SUBQ q<>+0(SB), BX + SBBQ q<>+8(SB), 
DI + SBBQ q<>+16(SB), R9 + SBBQ q<>+24(SB), R11 + SBBQ $0, R12 + JCS done_9 + MOVQ BX, 0(SI) + MOVQ DI, 8(SI) + MOVQ R9, 16(SI) + MOVQ R11, 24(SI) + SUBQ q<>+0(SB), BX + SBBQ q<>+8(SB), DI + SBBQ q<>+16(SB), R9 + SBBQ q<>+24(SB), R11 + SBBQ $0, R12 + JCS done_9 + MOVQ BX, 0(SI) + MOVQ DI, 8(SI) + MOVQ R9, 16(SI) + MOVQ R11, 24(SI) done_9: RET diff --git a/ecc/bls12-377/fr/element_test.go b/ecc/bls12-377/fr/element_test.go index 52098e507..b576a09f4 100644 --- a/ecc/bls12-377/fr/element_test.go +++ b/ecc/bls12-377/fr/element_test.go @@ -710,7 +710,7 @@ func TestElementLexicographicallyLargest(t *testing.T) { func TestElementVecOps(t *testing.T) { assert := require.New(t) - const N = 4*16 + 4 + const N = 1024*16 + 4 a := make(Vector, N) b := make(Vector, N) c := make(Vector, N) @@ -756,7 +756,7 @@ func TestElementVecOps(t *testing.T) { func BenchmarkElementVecOps(b *testing.B) { // note; to benchmark against "no asm" version, use the following // build tag: -tags purego - const N = 1024 * 10 + const N = 1 << 20 a1 := make(Vector, N) b1 := make(Vector, N) c1 := make(Vector, N) diff --git a/ecc/bls12-381/fp/element_test.go b/ecc/bls12-381/fp/element_test.go index da7263c5a..88974f2cd 100644 --- a/ecc/bls12-381/fp/element_test.go +++ b/ecc/bls12-381/fp/element_test.go @@ -714,7 +714,7 @@ func TestElementLexicographicallyLargest(t *testing.T) { func TestElementVecOps(t *testing.T) { assert := require.New(t) - const N = 7 + const N = 1024*16 + 4 a := make(Vector, N) b := make(Vector, N) c := make(Vector, N) @@ -753,13 +753,14 @@ func TestElementVecOps(t *testing.T) { for i := 0; i < N; i++ { sum.Add(&sum, &c[i]) } + assert.True(sum.Equal(&computed), "Vector sum failed") } func BenchmarkElementVecOps(b *testing.B) { // note; to benchmark against "no asm" version, use the following // build tag: -tags purego - const N = 1024 + const N = 1 << 20 a1 := make(Vector, N) b1 := make(Vector, N) c1 := make(Vector, N) diff --git a/ecc/bls12-381/fr/asm.go b/ecc/bls12-381/fr/asm.go index da061913b..c756b153a 100644 --- a/ecc/bls12-381/fr/asm.go +++ b/ecc/bls12-381/fr/asm.go @@ -22,6 +22,8 @@ package fr import "golang.org/x/sys/cpu" var ( - supportAdx = cpu.X86.HasADX && cpu.X86.HasBMI2 - _ = supportAdx + supportAdx = cpu.X86.HasADX && cpu.X86.HasBMI2 + _ = supportAdx + supportAvx512 = cpu.X86.HasAVX512 && cpu.X86.HasAVX512DQ + _ = supportAvx512 ) diff --git a/ecc/bls12-381/fr/asm_noadx.go b/ecc/bls12-381/fr/asm_noadx.go index 7f52ffa19..a116cd66e 100644 --- a/ecc/bls12-381/fr/asm_noadx.go +++ b/ecc/bls12-381/fr/asm_noadx.go @@ -23,6 +23,8 @@ package fr // certain errors (like fatal error: missing stackmap) // this ensures we test all asm path. 
var ( - supportAdx = false - _ = supportAdx + supportAdx = false + _ = supportAdx + supportAvx512 = false + _ = supportAvx512 ) diff --git a/ecc/bls12-381/fr/element_mul_amd64.s b/ecc/bls12-381/fr/element_mul_amd64.s index 396d990b7..36064570f 100644 --- a/ecc/bls12-381/fr/element_mul_amd64.s +++ b/ecc/bls12-381/fr/element_mul_amd64.s @@ -27,6 +27,9 @@ GLOBL q<>(SB), (RODATA+NOPTR), $32 // qInv0 q'[0] DATA qInv0<>(SB)/8, $0xfffffffeffffffff GLOBL qInv0<>(SB), (RODATA+NOPTR), $8 +// Mu +DATA mu<>(SB)/8, $0x00000002355094ed +GLOBL mu<>(SB), (RODATA+NOPTR), $8 #define REDUCE(ra0, ra1, ra2, ra3, rb0, rb1, rb2, rb3) \ MOVQ ra0, rb0; \ diff --git a/ecc/bls12-381/fr/element_ops_amd64.go b/ecc/bls12-381/fr/element_ops_amd64.go index 88c2d0c58..fe4606378 100644 --- a/ecc/bls12-381/fr/element_ops_amd64.go +++ b/ecc/bls12-381/fr/element_ops_amd64.go @@ -83,7 +83,12 @@ func scalarMulVec(res, a, b *Element, n uint64) // Sum computes the sum of all elements in the vector. func (vector *Vector) Sum() (res Element) { - if len(*vector) == 0 { + n := uint64(len(*vector)) + const minN = 16 * 7 // AVX512 slower than generic for small n + const maxN = (1 << 32) + 1 + if !supportAvx512 || n <= minN || n >= maxN { + // call sumVecGeneric + sumVecGeneric(&res, *vector) return } sumVec(&res, &(*vector)[0], uint64(len(*vector))) diff --git a/ecc/bls12-381/fr/element_ops_amd64.s b/ecc/bls12-381/fr/element_ops_amd64.s index b9a186e46..034d36141 100644 --- a/ecc/bls12-381/fr/element_ops_amd64.s +++ b/ecc/bls12-381/fr/element_ops_amd64.s @@ -27,6 +27,9 @@ GLOBL q<>(SB), (RODATA+NOPTR), $32 // qInv0 q'[0] DATA qInv0<>(SB)/8, $0xfffffffeffffffff GLOBL qInv0<>(SB), (RODATA+NOPTR), $8 +// Mu +DATA mu<>(SB)/8, $0x00000002355094ed +GLOBL mu<>(SB), (RODATA+NOPTR), $8 #define REDUCE(ra0, ra1, ra2, ra3, rb0, rb1, rb2, rb3) \ MOVQ ra0, rb0; \ @@ -628,46 +631,206 @@ noAdx_5: // sumVec(res, a *Element, n uint64) res = sum(a[0...n]) TEXT ·sumVec(SB), NOSPLIT, $0-24 - MOVQ a+8(FP), AX - MOVQ n+16(FP), DX - VXORQ ZMM0, ZMM0, ZMM0 - VMOVDQA64 ZMM0, ZMM1 - VMOVDQA64 ZMM0, ZMM2 - VMOVDQA64 ZMM0, ZMM3 - TESTQ DX, DX - JEQ done_9 // n == 0, we are done - MOVQ DX, CX - ANDQ $3, CX - SHRQ $2, DX - CMPQ $1, CX - JEQ r1_10 // we have 1 remaining element - CMPQ $2, CX - JEQ r2_11 // we have 2 remaining elements - CMPQ $3, CX - JNE loop_8 // == 0; we have 0 remaining elements + + // Derived from https://github.com/a16z/vectorized-fields + // The idea is to use Z registers to accumulate the sum of elements, 4 by 4 + // first, we handle the case where n % 4 != 0 and add to the accumulators the 1, 2 or 3 remaining elements + // then, we loop over the elements 4 by 4 and accumulate the sum in the Z registers + // finally, we reduce the sum and store it in res + // + // when we move an element of a into a Z register, we use VPMOVZXDQ + // let's note w0...w3 the 4 64bits words of ai: w0 = ai[0], w1 = ai[1], w2 = ai[2], w3 = ai[3] + // VPMOVZXDQ(ai, Z0) will result in + // Z0= [hi(w3), lo(w3), hi(w2), lo(w2), hi(w1), lo(w1), hi(w0), lo(w0)] + // with hi(wi) the high 32 bits of wi and lo(wi) the low 32 bits of wi + // we can safely add 2^32+1 times Z registers constructed this way without overflow + // since each of this lo/hi bits are moved into a "64bits" slot + // N = 2^64-1 / 2^32-1 = 2^32+1 + // + // we then propagate the carry using ADOXQ and ADCXQ + // r0 = w0l + lo(woh) + // r1 = carry + hi(woh) + w1l + lo(w1h) + // r2 = carry + hi(w1h) + w2l + lo(w2h) + // r3 = carry + hi(w2h) + w3l + lo(w3h) + // r4 = carry + hi(w3h) + // we then reduce the 
sum using a single-word Barrett reduction + + MOVQ a+8(FP), R14 + MOVQ n+16(FP), R15 + + // initialize accumulators Z0, Z1, Z2, Z3 + VXORPS Z0, Z0, Z0 + VMOVDQA64 Z0, Z1 + VMOVDQA64 Z0, Z2 + VMOVDQA64 Z0, Z3 + + // n % 4 -> CX + // n / 4 -> R15 + MOVQ R15, CX + ANDQ $3, CX + SHRQ $2, R15 + CMPQ CX, $1 + JEQ r1_10 // we have 1 remaining element + CMPQ CX, $2 + JEQ r2_11 // we have 2 remaining elements + CMPQ CX, $3 + JNE loop4by4_8 // we have 0 remaining elements // we have 3 remaining elements - VPMOVZXDQ 2*32(AX), ZMM4 - VPADDQ ZMM4, ZMM0, ZMM0 + VPMOVZXDQ 2*32(R14), Z4 + VPADDQ Z4, Z0, Z0 r2_11: // we have 2 remaining elements - VPMOVZXDQ 1*32(AX), ZMM4 - VPADDQ ZMM4, ZMM1, ZMM1 + VPMOVZXDQ 1*32(R14), Z4 + VPADDQ Z4, Z1, Z1 r1_10: // we have 1 remaining element - VPMOVZXDQ 0*32(AX), ZMM4 - VPADDQ ZMM4, ZMM2, ZMM2 - TESTQ DX, DX - JEQ done_9 // n == 0, we are done - -loop_8: - // increment pointers to visit next element - ADDQ $128, AX - DECQ DX // decrement n - JMP loop_8 + VPMOVZXDQ 0*32(R14), Z4 + VPADDQ Z4, Z2, Z2 + MOVQ $32, DX + IMULQ CX, DX + ADDQ DX, R14 + +loop4by4_8: + TESTQ R15, R15 + JEQ accumulate_12 // n == 0, we are going to accumulate + VPMOVZXDQ 0*32(R14), Z4 + VPADDQ Z4, Z0, Z0 + VPMOVZXDQ 1*32(R14), Z4 + VPADDQ Z4, Z1, Z1 + VPMOVZXDQ 2*32(R14), Z4 + VPADDQ Z4, Z2, Z2 + VPMOVZXDQ 3*32(R14), Z4 + VPADDQ Z4, Z3, Z3 + + // increment pointers to visit next 4 elements + ADDQ $128, R14 + DECQ R15 // decrement n + JMP loop4by4_8 + +accumulate_12: + // accumulate the 4 Z registers into Z0 + VPADDQ Z1, Z0, Z0 + VPADDQ Z3, Z2, Z2 + VPADDQ Z2, Z0, Z0 + + // carry propagation + // lo(w0) -> BX + // hi(w0) -> SI + // lo(w1) -> DI + // hi(w1) -> R8 + // lo(w2) -> R9 + // hi(w2) -> R10 + // lo(w3) -> R11 + // hi(w3) -> R12 + VMOVQ X0, BX + VALIGNQ $1, Z0, Z0, Z0 + VMOVQ X0, SI + VALIGNQ $1, Z0, Z0, Z0 + VMOVQ X0, DI + VALIGNQ $1, Z0, Z0, Z0 + VMOVQ X0, R8 + VALIGNQ $1, Z0, Z0, Z0 + VMOVQ X0, R9 + VALIGNQ $1, Z0, Z0, Z0 + VMOVQ X0, R10 + VALIGNQ $1, Z0, Z0, Z0 + VMOVQ X0, R11 + VALIGNQ $1, Z0, Z0, Z0 + VMOVQ X0, R12 + + // lo(hi(wo)) -> R13 + // lo(hi(w1)) -> CX + // lo(hi(w2)) -> R15 + // lo(hi(w3)) -> R14 + XORQ AX, AX // clear the flags + MOVQ SI, R13 + ANDQ $0xffffffff, R13 + SHLQ $32, R13 + SHRQ $32, SI + MOVQ R8, CX + ANDQ $0xffffffff, CX + SHLQ $32, CX + SHRQ $32, R8 + MOVQ R10, R15 + ANDQ $0xffffffff, R15 + SHLQ $32, R15 + SHRQ $32, R10 + MOVQ R12, R14 + ANDQ $0xffffffff, R14 + SHLQ $32, R14 + SHRQ $32, R12 + + // r0 = w0l + lo(woh) + // r1 = carry + hi(woh) + w1l + lo(w1h) + // r2 = carry + hi(w1h) + w2l + lo(w2h) + // r3 = carry + hi(w2h) + w3l + lo(w3h) + // r4 = carry + hi(w3h) + + XORQ AX, AX // clear the flags + ADOXQ R13, BX + ADOXQ CX, DI + ADCXQ SI, DI + ADOXQ R15, R9 + ADCXQ R8, R9 + ADOXQ R14, R11 + ADCXQ R10, R11 + ADOXQ AX, R12 + ADCXQ AX, R12 + + // r[0] -> BX + // r[1] -> DI + // r[2] -> R9 + // r[3] -> R11 + // r[4] -> R12 + // reduce using single-word Barrett + // mu=2^288 / q -> SI + MOVQ mu<>(SB), SI + MOVQ R11, AX + SHRQ $32, R12, AX + MULQ SI + MULXQ q<>+0(SB), AX, SI + SUBQ AX, BX + SBBQ SI, DI + MULXQ q<>+16(SB), AX, SI + SBBQ AX, R9 + SBBQ SI, R11 + SBBQ $0, R12 + MULXQ q<>+8(SB), AX, SI + SUBQ AX, DI + SBBQ SI, R9 + MULXQ q<>+24(SB), AX, SI + SBBQ AX, R11 + SBBQ SI, R12 + MOVQ res+0(FP), SI + MOVQ BX, 0(SI) + MOVQ DI, 8(SI) + MOVQ R9, 16(SI) + MOVQ R11, 24(SI) + + // TODO @gbotrel check if 2 conditional substracts is guaranteed to be suffisant for mod reduce + SUBQ q<>+0(SB), BX + SBBQ q<>+8(SB), DI + SBBQ q<>+16(SB), R9 + SBBQ q<>+24(SB), R11 + SBBQ $0, 
R12 + JCS done_9 + MOVQ BX, 0(SI) + MOVQ DI, 8(SI) + MOVQ R9, 16(SI) + MOVQ R11, 24(SI) + SUBQ q<>+0(SB), BX + SBBQ q<>+8(SB), DI + SBBQ q<>+16(SB), R9 + SBBQ q<>+24(SB), R11 + SBBQ $0, R12 + JCS done_9 + MOVQ BX, 0(SI) + MOVQ DI, 8(SI) + MOVQ R9, 16(SI) + MOVQ R11, 24(SI) done_9: - MOVQ res+0(FP), AX RET diff --git a/ecc/bls12-381/fr/element_test.go b/ecc/bls12-381/fr/element_test.go index a0ce1d6ea..ff0b3c188 100644 --- a/ecc/bls12-381/fr/element_test.go +++ b/ecc/bls12-381/fr/element_test.go @@ -710,7 +710,7 @@ func TestElementLexicographicallyLargest(t *testing.T) { func TestElementVecOps(t *testing.T) { assert := require.New(t) - const N = 7 + const N = 1024*16 + 4 a := make(Vector, N) b := make(Vector, N) c := make(Vector, N) @@ -749,13 +749,14 @@ func TestElementVecOps(t *testing.T) { for i := 0; i < N; i++ { sum.Add(&sum, &c[i]) } + assert.True(sum.Equal(&computed), "Vector sum failed") } func BenchmarkElementVecOps(b *testing.B) { // note; to benchmark against "no asm" version, use the following // build tag: -tags purego - const N = 1024 + const N = 1 << 20 a1 := make(Vector, N) b1 := make(Vector, N) c1 := make(Vector, N) diff --git a/ecc/bls24-315/fp/element_test.go b/ecc/bls24-315/fp/element_test.go index 31f018b1b..bdf5071ad 100644 --- a/ecc/bls24-315/fp/element_test.go +++ b/ecc/bls24-315/fp/element_test.go @@ -712,7 +712,7 @@ func TestElementLexicographicallyLargest(t *testing.T) { func TestElementVecOps(t *testing.T) { assert := require.New(t) - const N = 7 + const N = 1024*16 + 4 a := make(Vector, N) b := make(Vector, N) c := make(Vector, N) @@ -751,13 +751,14 @@ func TestElementVecOps(t *testing.T) { for i := 0; i < N; i++ { sum.Add(&sum, &c[i]) } + assert.True(sum.Equal(&computed), "Vector sum failed") } func BenchmarkElementVecOps(b *testing.B) { // note; to benchmark against "no asm" version, use the following // build tag: -tags purego - const N = 1024 + const N = 1 << 20 a1 := make(Vector, N) b1 := make(Vector, N) c1 := make(Vector, N) diff --git a/ecc/bls24-315/fr/asm.go b/ecc/bls24-315/fr/asm.go index da061913b..c756b153a 100644 --- a/ecc/bls24-315/fr/asm.go +++ b/ecc/bls24-315/fr/asm.go @@ -22,6 +22,8 @@ package fr import "golang.org/x/sys/cpu" var ( - supportAdx = cpu.X86.HasADX && cpu.X86.HasBMI2 - _ = supportAdx + supportAdx = cpu.X86.HasADX && cpu.X86.HasBMI2 + _ = supportAdx + supportAvx512 = cpu.X86.HasAVX512 && cpu.X86.HasAVX512DQ + _ = supportAvx512 ) diff --git a/ecc/bls24-315/fr/asm_noadx.go b/ecc/bls24-315/fr/asm_noadx.go index 7f52ffa19..a116cd66e 100644 --- a/ecc/bls24-315/fr/asm_noadx.go +++ b/ecc/bls24-315/fr/asm_noadx.go @@ -23,6 +23,8 @@ package fr // certain errors (like fatal error: missing stackmap) // this ensures we test all asm path. 
var ( - supportAdx = false - _ = supportAdx + supportAdx = false + _ = supportAdx + supportAvx512 = false + _ = supportAvx512 ) diff --git a/ecc/bls24-315/fr/element_mul_amd64.s b/ecc/bls24-315/fr/element_mul_amd64.s index d028fed20..2afbbd3fe 100644 --- a/ecc/bls24-315/fr/element_mul_amd64.s +++ b/ecc/bls24-315/fr/element_mul_amd64.s @@ -27,6 +27,9 @@ GLOBL q<>(SB), (RODATA+NOPTR), $32 // qInv0 q'[0] DATA qInv0<>(SB)/8, $0x1e5035fd00bfffff GLOBL qInv0<>(SB), (RODATA+NOPTR), $8 +// Mu +DATA mu<>(SB)/8, $0x0000000a112d9c09 +GLOBL mu<>(SB), (RODATA+NOPTR), $8 #define REDUCE(ra0, ra1, ra2, ra3, rb0, rb1, rb2, rb3) \ MOVQ ra0, rb0; \ diff --git a/ecc/bls24-315/fr/element_ops_amd64.go b/ecc/bls24-315/fr/element_ops_amd64.go index 88c2d0c58..fe4606378 100644 --- a/ecc/bls24-315/fr/element_ops_amd64.go +++ b/ecc/bls24-315/fr/element_ops_amd64.go @@ -83,7 +83,12 @@ func scalarMulVec(res, a, b *Element, n uint64) // Sum computes the sum of all elements in the vector. func (vector *Vector) Sum() (res Element) { - if len(*vector) == 0 { + n := uint64(len(*vector)) + const minN = 16 * 7 // AVX512 slower than generic for small n + const maxN = (1 << 32) + 1 + if !supportAvx512 || n <= minN || n >= maxN { + // call sumVecGeneric + sumVecGeneric(&res, *vector) return } sumVec(&res, &(*vector)[0], uint64(len(*vector))) diff --git a/ecc/bls24-315/fr/element_ops_amd64.s b/ecc/bls24-315/fr/element_ops_amd64.s index 8f4991b50..9a05cb98f 100644 --- a/ecc/bls24-315/fr/element_ops_amd64.s +++ b/ecc/bls24-315/fr/element_ops_amd64.s @@ -27,6 +27,9 @@ GLOBL q<>(SB), (RODATA+NOPTR), $32 // qInv0 q'[0] DATA qInv0<>(SB)/8, $0x1e5035fd00bfffff GLOBL qInv0<>(SB), (RODATA+NOPTR), $8 +// Mu +DATA mu<>(SB)/8, $0x0000000a112d9c09 +GLOBL mu<>(SB), (RODATA+NOPTR), $8 #define REDUCE(ra0, ra1, ra2, ra3, rb0, rb1, rb2, rb3) \ MOVQ ra0, rb0; \ @@ -628,46 +631,206 @@ noAdx_5: // sumVec(res, a *Element, n uint64) res = sum(a[0...n]) TEXT ·sumVec(SB), NOSPLIT, $0-24 - MOVQ a+8(FP), AX - MOVQ n+16(FP), DX - VXORQ ZMM0, ZMM0, ZMM0 - VMOVDQA64 ZMM0, ZMM1 - VMOVDQA64 ZMM0, ZMM2 - VMOVDQA64 ZMM0, ZMM3 - TESTQ DX, DX - JEQ done_9 // n == 0, we are done - MOVQ DX, CX - ANDQ $3, CX - SHRQ $2, DX - CMPQ $1, CX - JEQ r1_10 // we have 1 remaining element - CMPQ $2, CX - JEQ r2_11 // we have 2 remaining elements - CMPQ $3, CX - JNE loop_8 // == 0; we have 0 remaining elements + + // Derived from https://github.com/a16z/vectorized-fields + // The idea is to use Z registers to accumulate the sum of elements, 4 by 4 + // first, we handle the case where n % 4 != 0 and add to the accumulators the 1, 2 or 3 remaining elements + // then, we loop over the elements 4 by 4 and accumulate the sum in the Z registers + // finally, we reduce the sum and store it in res + // + // when we move an element of a into a Z register, we use VPMOVZXDQ + // let's note w0...w3 the 4 64bits words of ai: w0 = ai[0], w1 = ai[1], w2 = ai[2], w3 = ai[3] + // VPMOVZXDQ(ai, Z0) will result in + // Z0= [hi(w3), lo(w3), hi(w2), lo(w2), hi(w1), lo(w1), hi(w0), lo(w0)] + // with hi(wi) the high 32 bits of wi and lo(wi) the low 32 bits of wi + // we can safely add 2^32+1 times Z registers constructed this way without overflow + // since each of this lo/hi bits are moved into a "64bits" slot + // N = 2^64-1 / 2^32-1 = 2^32+1 + // + // we then propagate the carry using ADOXQ and ADCXQ + // r0 = w0l + lo(woh) + // r1 = carry + hi(woh) + w1l + lo(w1h) + // r2 = carry + hi(w1h) + w2l + lo(w2h) + // r3 = carry + hi(w2h) + w3l + lo(w3h) + // r4 = carry + hi(w3h) + // we then reduce the 
sum using a single-word Barrett reduction + + MOVQ a+8(FP), R14 + MOVQ n+16(FP), R15 + + // initialize accumulators Z0, Z1, Z2, Z3 + VXORPS Z0, Z0, Z0 + VMOVDQA64 Z0, Z1 + VMOVDQA64 Z0, Z2 + VMOVDQA64 Z0, Z3 + + // n % 4 -> CX + // n / 4 -> R15 + MOVQ R15, CX + ANDQ $3, CX + SHRQ $2, R15 + CMPQ CX, $1 + JEQ r1_10 // we have 1 remaining element + CMPQ CX, $2 + JEQ r2_11 // we have 2 remaining elements + CMPQ CX, $3 + JNE loop4by4_8 // we have 0 remaining elements // we have 3 remaining elements - VPMOVZXDQ 2*32(AX), ZMM4 - VPADDQ ZMM4, ZMM0, ZMM0 + VPMOVZXDQ 2*32(R14), Z4 + VPADDQ Z4, Z0, Z0 r2_11: // we have 2 remaining elements - VPMOVZXDQ 1*32(AX), ZMM4 - VPADDQ ZMM4, ZMM1, ZMM1 + VPMOVZXDQ 1*32(R14), Z4 + VPADDQ Z4, Z1, Z1 r1_10: // we have 1 remaining element - VPMOVZXDQ 0*32(AX), ZMM4 - VPADDQ ZMM4, ZMM2, ZMM2 - TESTQ DX, DX - JEQ done_9 // n == 0, we are done - -loop_8: - // increment pointers to visit next element - ADDQ $128, AX - DECQ DX // decrement n - JMP loop_8 + VPMOVZXDQ 0*32(R14), Z4 + VPADDQ Z4, Z2, Z2 + MOVQ $32, DX + IMULQ CX, DX + ADDQ DX, R14 + +loop4by4_8: + TESTQ R15, R15 + JEQ accumulate_12 // n == 0, we are going to accumulate + VPMOVZXDQ 0*32(R14), Z4 + VPADDQ Z4, Z0, Z0 + VPMOVZXDQ 1*32(R14), Z4 + VPADDQ Z4, Z1, Z1 + VPMOVZXDQ 2*32(R14), Z4 + VPADDQ Z4, Z2, Z2 + VPMOVZXDQ 3*32(R14), Z4 + VPADDQ Z4, Z3, Z3 + + // increment pointers to visit next 4 elements + ADDQ $128, R14 + DECQ R15 // decrement n + JMP loop4by4_8 + +accumulate_12: + // accumulate the 4 Z registers into Z0 + VPADDQ Z1, Z0, Z0 + VPADDQ Z3, Z2, Z2 + VPADDQ Z2, Z0, Z0 + + // carry propagation + // lo(w0) -> BX + // hi(w0) -> SI + // lo(w1) -> DI + // hi(w1) -> R8 + // lo(w2) -> R9 + // hi(w2) -> R10 + // lo(w3) -> R11 + // hi(w3) -> R12 + VMOVQ X0, BX + VALIGNQ $1, Z0, Z0, Z0 + VMOVQ X0, SI + VALIGNQ $1, Z0, Z0, Z0 + VMOVQ X0, DI + VALIGNQ $1, Z0, Z0, Z0 + VMOVQ X0, R8 + VALIGNQ $1, Z0, Z0, Z0 + VMOVQ X0, R9 + VALIGNQ $1, Z0, Z0, Z0 + VMOVQ X0, R10 + VALIGNQ $1, Z0, Z0, Z0 + VMOVQ X0, R11 + VALIGNQ $1, Z0, Z0, Z0 + VMOVQ X0, R12 + + // lo(hi(wo)) -> R13 + // lo(hi(w1)) -> CX + // lo(hi(w2)) -> R15 + // lo(hi(w3)) -> R14 + XORQ AX, AX // clear the flags + MOVQ SI, R13 + ANDQ $0xffffffff, R13 + SHLQ $32, R13 + SHRQ $32, SI + MOVQ R8, CX + ANDQ $0xffffffff, CX + SHLQ $32, CX + SHRQ $32, R8 + MOVQ R10, R15 + ANDQ $0xffffffff, R15 + SHLQ $32, R15 + SHRQ $32, R10 + MOVQ R12, R14 + ANDQ $0xffffffff, R14 + SHLQ $32, R14 + SHRQ $32, R12 + + // r0 = w0l + lo(woh) + // r1 = carry + hi(woh) + w1l + lo(w1h) + // r2 = carry + hi(w1h) + w2l + lo(w2h) + // r3 = carry + hi(w2h) + w3l + lo(w3h) + // r4 = carry + hi(w3h) + + XORQ AX, AX // clear the flags + ADOXQ R13, BX + ADOXQ CX, DI + ADCXQ SI, DI + ADOXQ R15, R9 + ADCXQ R8, R9 + ADOXQ R14, R11 + ADCXQ R10, R11 + ADOXQ AX, R12 + ADCXQ AX, R12 + + // r[0] -> BX + // r[1] -> DI + // r[2] -> R9 + // r[3] -> R11 + // r[4] -> R12 + // reduce using single-word Barrett + // mu=2^288 / q -> SI + MOVQ mu<>(SB), SI + MOVQ R11, AX + SHRQ $32, R12, AX + MULQ SI + MULXQ q<>+0(SB), AX, SI + SUBQ AX, BX + SBBQ SI, DI + MULXQ q<>+16(SB), AX, SI + SBBQ AX, R9 + SBBQ SI, R11 + SBBQ $0, R12 + MULXQ q<>+8(SB), AX, SI + SUBQ AX, DI + SBBQ SI, R9 + MULXQ q<>+24(SB), AX, SI + SBBQ AX, R11 + SBBQ SI, R12 + MOVQ res+0(FP), SI + MOVQ BX, 0(SI) + MOVQ DI, 8(SI) + MOVQ R9, 16(SI) + MOVQ R11, 24(SI) + + // TODO @gbotrel check if 2 conditional substracts is guaranteed to be suffisant for mod reduce + SUBQ q<>+0(SB), BX + SBBQ q<>+8(SB), DI + SBBQ q<>+16(SB), R9 + SBBQ q<>+24(SB), R11 + SBBQ $0, 
R12 + JCS done_9 + MOVQ BX, 0(SI) + MOVQ DI, 8(SI) + MOVQ R9, 16(SI) + MOVQ R11, 24(SI) + SUBQ q<>+0(SB), BX + SBBQ q<>+8(SB), DI + SBBQ q<>+16(SB), R9 + SBBQ q<>+24(SB), R11 + SBBQ $0, R12 + JCS done_9 + MOVQ BX, 0(SI) + MOVQ DI, 8(SI) + MOVQ R9, 16(SI) + MOVQ R11, 24(SI) done_9: - MOVQ res+0(FP), AX RET diff --git a/ecc/bls24-315/fr/element_test.go b/ecc/bls24-315/fr/element_test.go index 8e3b19942..007beea50 100644 --- a/ecc/bls24-315/fr/element_test.go +++ b/ecc/bls24-315/fr/element_test.go @@ -710,7 +710,7 @@ func TestElementLexicographicallyLargest(t *testing.T) { func TestElementVecOps(t *testing.T) { assert := require.New(t) - const N = 7 + const N = 1024*16 + 4 a := make(Vector, N) b := make(Vector, N) c := make(Vector, N) @@ -749,13 +749,14 @@ func TestElementVecOps(t *testing.T) { for i := 0; i < N; i++ { sum.Add(&sum, &c[i]) } + assert.True(sum.Equal(&computed), "Vector sum failed") } func BenchmarkElementVecOps(b *testing.B) { // note; to benchmark against "no asm" version, use the following // build tag: -tags purego - const N = 1024 + const N = 1 << 20 a1 := make(Vector, N) b1 := make(Vector, N) c1 := make(Vector, N) diff --git a/ecc/bls24-317/fp/element_test.go b/ecc/bls24-317/fp/element_test.go index 76758acc7..9a20746bc 100644 --- a/ecc/bls24-317/fp/element_test.go +++ b/ecc/bls24-317/fp/element_test.go @@ -712,7 +712,7 @@ func TestElementLexicographicallyLargest(t *testing.T) { func TestElementVecOps(t *testing.T) { assert := require.New(t) - const N = 7 + const N = 1024*16 + 4 a := make(Vector, N) b := make(Vector, N) c := make(Vector, N) @@ -751,13 +751,14 @@ func TestElementVecOps(t *testing.T) { for i := 0; i < N; i++ { sum.Add(&sum, &c[i]) } + assert.True(sum.Equal(&computed), "Vector sum failed") } func BenchmarkElementVecOps(b *testing.B) { // note; to benchmark against "no asm" version, use the following // build tag: -tags purego - const N = 1024 + const N = 1 << 20 a1 := make(Vector, N) b1 := make(Vector, N) c1 := make(Vector, N) diff --git a/ecc/bls24-317/fr/asm.go b/ecc/bls24-317/fr/asm.go index da061913b..c756b153a 100644 --- a/ecc/bls24-317/fr/asm.go +++ b/ecc/bls24-317/fr/asm.go @@ -22,6 +22,8 @@ package fr import "golang.org/x/sys/cpu" var ( - supportAdx = cpu.X86.HasADX && cpu.X86.HasBMI2 - _ = supportAdx + supportAdx = cpu.X86.HasADX && cpu.X86.HasBMI2 + _ = supportAdx + supportAvx512 = cpu.X86.HasAVX512 && cpu.X86.HasAVX512DQ + _ = supportAvx512 ) diff --git a/ecc/bls24-317/fr/asm_noadx.go b/ecc/bls24-317/fr/asm_noadx.go index 7f52ffa19..a116cd66e 100644 --- a/ecc/bls24-317/fr/asm_noadx.go +++ b/ecc/bls24-317/fr/asm_noadx.go @@ -23,6 +23,8 @@ package fr // certain errors (like fatal error: missing stackmap) // this ensures we test all asm path. 
var ( - supportAdx = false - _ = supportAdx + supportAdx = false + _ = supportAdx + supportAvx512 = false + _ = supportAvx512 ) diff --git a/ecc/bls24-317/fr/element_mul_amd64.s b/ecc/bls24-317/fr/element_mul_amd64.s index 6e58b40d6..77d9a3fc4 100644 --- a/ecc/bls24-317/fr/element_mul_amd64.s +++ b/ecc/bls24-317/fr/element_mul_amd64.s @@ -27,6 +27,9 @@ GLOBL q<>(SB), (RODATA+NOPTR), $32 // qInv0 q'[0] DATA qInv0<>(SB)/8, $0xefffffffffffffff GLOBL qInv0<>(SB), (RODATA+NOPTR), $8 +// Mu +DATA mu<>(SB)/8, $0x00000003c0421687 +GLOBL mu<>(SB), (RODATA+NOPTR), $8 #define REDUCE(ra0, ra1, ra2, ra3, rb0, rb1, rb2, rb3) \ MOVQ ra0, rb0; \ diff --git a/ecc/bls24-317/fr/element_ops_amd64.go b/ecc/bls24-317/fr/element_ops_amd64.go index 88c2d0c58..fe4606378 100644 --- a/ecc/bls24-317/fr/element_ops_amd64.go +++ b/ecc/bls24-317/fr/element_ops_amd64.go @@ -83,7 +83,12 @@ func scalarMulVec(res, a, b *Element, n uint64) // Sum computes the sum of all elements in the vector. func (vector *Vector) Sum() (res Element) { - if len(*vector) == 0 { + n := uint64(len(*vector)) + const minN = 16 * 7 // AVX512 slower than generic for small n + const maxN = (1 << 32) + 1 + if !supportAvx512 || n <= minN || n >= maxN { + // call sumVecGeneric + sumVecGeneric(&res, *vector) return } sumVec(&res, &(*vector)[0], uint64(len(*vector))) diff --git a/ecc/bls24-317/fr/element_ops_amd64.s b/ecc/bls24-317/fr/element_ops_amd64.s index ba8afdd99..ead13527b 100644 --- a/ecc/bls24-317/fr/element_ops_amd64.s +++ b/ecc/bls24-317/fr/element_ops_amd64.s @@ -27,6 +27,9 @@ GLOBL q<>(SB), (RODATA+NOPTR), $32 // qInv0 q'[0] DATA qInv0<>(SB)/8, $0xefffffffffffffff GLOBL qInv0<>(SB), (RODATA+NOPTR), $8 +// Mu +DATA mu<>(SB)/8, $0x00000003c0421687 +GLOBL mu<>(SB), (RODATA+NOPTR), $8 #define REDUCE(ra0, ra1, ra2, ra3, rb0, rb1, rb2, rb3) \ MOVQ ra0, rb0; \ @@ -628,46 +631,206 @@ noAdx_5: // sumVec(res, a *Element, n uint64) res = sum(a[0...n]) TEXT ·sumVec(SB), NOSPLIT, $0-24 - MOVQ a+8(FP), AX - MOVQ n+16(FP), DX - VXORQ ZMM0, ZMM0, ZMM0 - VMOVDQA64 ZMM0, ZMM1 - VMOVDQA64 ZMM0, ZMM2 - VMOVDQA64 ZMM0, ZMM3 - TESTQ DX, DX - JEQ done_9 // n == 0, we are done - MOVQ DX, CX - ANDQ $3, CX - SHRQ $2, DX - CMPQ $1, CX - JEQ r1_10 // we have 1 remaining element - CMPQ $2, CX - JEQ r2_11 // we have 2 remaining elements - CMPQ $3, CX - JNE loop_8 // == 0; we have 0 remaining elements + + // Derived from https://github.com/a16z/vectorized-fields + // The idea is to use Z registers to accumulate the sum of elements, 4 by 4 + // first, we handle the case where n % 4 != 0 and add to the accumulators the 1, 2 or 3 remaining elements + // then, we loop over the elements 4 by 4 and accumulate the sum in the Z registers + // finally, we reduce the sum and store it in res + // + // when we move an element of a into a Z register, we use VPMOVZXDQ + // let's note w0...w3 the 4 64bits words of ai: w0 = ai[0], w1 = ai[1], w2 = ai[2], w3 = ai[3] + // VPMOVZXDQ(ai, Z0) will result in + // Z0= [hi(w3), lo(w3), hi(w2), lo(w2), hi(w1), lo(w1), hi(w0), lo(w0)] + // with hi(wi) the high 32 bits of wi and lo(wi) the low 32 bits of wi + // we can safely add 2^32+1 times Z registers constructed this way without overflow + // since each of this lo/hi bits are moved into a "64bits" slot + // N = 2^64-1 / 2^32-1 = 2^32+1 + // + // we then propagate the carry using ADOXQ and ADCXQ + // r0 = w0l + lo(woh) + // r1 = carry + hi(woh) + w1l + lo(w1h) + // r2 = carry + hi(w1h) + w2l + lo(w2h) + // r3 = carry + hi(w2h) + w3l + lo(w3h) + // r4 = carry + hi(w3h) + // we then reduce the 
sum using a single-word Barrett reduction + + MOVQ a+8(FP), R14 + MOVQ n+16(FP), R15 + + // initialize accumulators Z0, Z1, Z2, Z3 + VXORPS Z0, Z0, Z0 + VMOVDQA64 Z0, Z1 + VMOVDQA64 Z0, Z2 + VMOVDQA64 Z0, Z3 + + // n % 4 -> CX + // n / 4 -> R15 + MOVQ R15, CX + ANDQ $3, CX + SHRQ $2, R15 + CMPQ CX, $1 + JEQ r1_10 // we have 1 remaining element + CMPQ CX, $2 + JEQ r2_11 // we have 2 remaining elements + CMPQ CX, $3 + JNE loop4by4_8 // we have 0 remaining elements // we have 3 remaining elements - VPMOVZXDQ 2*32(AX), ZMM4 - VPADDQ ZMM4, ZMM0, ZMM0 + VPMOVZXDQ 2*32(R14), Z4 + VPADDQ Z4, Z0, Z0 r2_11: // we have 2 remaining elements - VPMOVZXDQ 1*32(AX), ZMM4 - VPADDQ ZMM4, ZMM1, ZMM1 + VPMOVZXDQ 1*32(R14), Z4 + VPADDQ Z4, Z1, Z1 r1_10: // we have 1 remaining element - VPMOVZXDQ 0*32(AX), ZMM4 - VPADDQ ZMM4, ZMM2, ZMM2 - TESTQ DX, DX - JEQ done_9 // n == 0, we are done - -loop_8: - // increment pointers to visit next element - ADDQ $128, AX - DECQ DX // decrement n - JMP loop_8 + VPMOVZXDQ 0*32(R14), Z4 + VPADDQ Z4, Z2, Z2 + MOVQ $32, DX + IMULQ CX, DX + ADDQ DX, R14 + +loop4by4_8: + TESTQ R15, R15 + JEQ accumulate_12 // n == 0, we are going to accumulate + VPMOVZXDQ 0*32(R14), Z4 + VPADDQ Z4, Z0, Z0 + VPMOVZXDQ 1*32(R14), Z4 + VPADDQ Z4, Z1, Z1 + VPMOVZXDQ 2*32(R14), Z4 + VPADDQ Z4, Z2, Z2 + VPMOVZXDQ 3*32(R14), Z4 + VPADDQ Z4, Z3, Z3 + + // increment pointers to visit next 4 elements + ADDQ $128, R14 + DECQ R15 // decrement n + JMP loop4by4_8 + +accumulate_12: + // accumulate the 4 Z registers into Z0 + VPADDQ Z1, Z0, Z0 + VPADDQ Z3, Z2, Z2 + VPADDQ Z2, Z0, Z0 + + // carry propagation + // lo(w0) -> BX + // hi(w0) -> SI + // lo(w1) -> DI + // hi(w1) -> R8 + // lo(w2) -> R9 + // hi(w2) -> R10 + // lo(w3) -> R11 + // hi(w3) -> R12 + VMOVQ X0, BX + VALIGNQ $1, Z0, Z0, Z0 + VMOVQ X0, SI + VALIGNQ $1, Z0, Z0, Z0 + VMOVQ X0, DI + VALIGNQ $1, Z0, Z0, Z0 + VMOVQ X0, R8 + VALIGNQ $1, Z0, Z0, Z0 + VMOVQ X0, R9 + VALIGNQ $1, Z0, Z0, Z0 + VMOVQ X0, R10 + VALIGNQ $1, Z0, Z0, Z0 + VMOVQ X0, R11 + VALIGNQ $1, Z0, Z0, Z0 + VMOVQ X0, R12 + + // lo(hi(wo)) -> R13 + // lo(hi(w1)) -> CX + // lo(hi(w2)) -> R15 + // lo(hi(w3)) -> R14 + XORQ AX, AX // clear the flags + MOVQ SI, R13 + ANDQ $0xffffffff, R13 + SHLQ $32, R13 + SHRQ $32, SI + MOVQ R8, CX + ANDQ $0xffffffff, CX + SHLQ $32, CX + SHRQ $32, R8 + MOVQ R10, R15 + ANDQ $0xffffffff, R15 + SHLQ $32, R15 + SHRQ $32, R10 + MOVQ R12, R14 + ANDQ $0xffffffff, R14 + SHLQ $32, R14 + SHRQ $32, R12 + + // r0 = w0l + lo(woh) + // r1 = carry + hi(woh) + w1l + lo(w1h) + // r2 = carry + hi(w1h) + w2l + lo(w2h) + // r3 = carry + hi(w2h) + w3l + lo(w3h) + // r4 = carry + hi(w3h) + + XORQ AX, AX // clear the flags + ADOXQ R13, BX + ADOXQ CX, DI + ADCXQ SI, DI + ADOXQ R15, R9 + ADCXQ R8, R9 + ADOXQ R14, R11 + ADCXQ R10, R11 + ADOXQ AX, R12 + ADCXQ AX, R12 + + // r[0] -> BX + // r[1] -> DI + // r[2] -> R9 + // r[3] -> R11 + // r[4] -> R12 + // reduce using single-word Barrett + // mu=2^288 / q -> SI + MOVQ mu<>(SB), SI + MOVQ R11, AX + SHRQ $32, R12, AX + MULQ SI + MULXQ q<>+0(SB), AX, SI + SUBQ AX, BX + SBBQ SI, DI + MULXQ q<>+16(SB), AX, SI + SBBQ AX, R9 + SBBQ SI, R11 + SBBQ $0, R12 + MULXQ q<>+8(SB), AX, SI + SUBQ AX, DI + SBBQ SI, R9 + MULXQ q<>+24(SB), AX, SI + SBBQ AX, R11 + SBBQ SI, R12 + MOVQ res+0(FP), SI + MOVQ BX, 0(SI) + MOVQ DI, 8(SI) + MOVQ R9, 16(SI) + MOVQ R11, 24(SI) + + // TODO @gbotrel check if 2 conditional substracts is guaranteed to be suffisant for mod reduce + SUBQ q<>+0(SB), BX + SBBQ q<>+8(SB), DI + SBBQ q<>+16(SB), R9 + SBBQ q<>+24(SB), R11 + SBBQ $0, 
R12 + JCS done_9 + MOVQ BX, 0(SI) + MOVQ DI, 8(SI) + MOVQ R9, 16(SI) + MOVQ R11, 24(SI) + SUBQ q<>+0(SB), BX + SBBQ q<>+8(SB), DI + SBBQ q<>+16(SB), R9 + SBBQ q<>+24(SB), R11 + SBBQ $0, R12 + JCS done_9 + MOVQ BX, 0(SI) + MOVQ DI, 8(SI) + MOVQ R9, 16(SI) + MOVQ R11, 24(SI) done_9: - MOVQ res+0(FP), AX RET diff --git a/ecc/bls24-317/fr/element_test.go b/ecc/bls24-317/fr/element_test.go index 85c80207c..881f55ff3 100644 --- a/ecc/bls24-317/fr/element_test.go +++ b/ecc/bls24-317/fr/element_test.go @@ -710,7 +710,7 @@ func TestElementLexicographicallyLargest(t *testing.T) { func TestElementVecOps(t *testing.T) { assert := require.New(t) - const N = 7 + const N = 1024*16 + 4 a := make(Vector, N) b := make(Vector, N) c := make(Vector, N) @@ -749,13 +749,14 @@ func TestElementVecOps(t *testing.T) { for i := 0; i < N; i++ { sum.Add(&sum, &c[i]) } + assert.True(sum.Equal(&computed), "Vector sum failed") } func BenchmarkElementVecOps(b *testing.B) { // note; to benchmark against "no asm" version, use the following // build tag: -tags purego - const N = 1024 + const N = 1 << 20 a1 := make(Vector, N) b1 := make(Vector, N) c1 := make(Vector, N) diff --git a/ecc/bn254/fp/asm.go b/ecc/bn254/fp/asm.go index 0481989ec..bb604d40b 100644 --- a/ecc/bn254/fp/asm.go +++ b/ecc/bn254/fp/asm.go @@ -22,6 +22,8 @@ package fp import "golang.org/x/sys/cpu" var ( - supportAdx = cpu.X86.HasADX && cpu.X86.HasBMI2 - _ = supportAdx + supportAdx = cpu.X86.HasADX && cpu.X86.HasBMI2 + _ = supportAdx + supportAvx512 = cpu.X86.HasAVX512 && cpu.X86.HasAVX512DQ + _ = supportAvx512 ) diff --git a/ecc/bn254/fp/asm_noadx.go b/ecc/bn254/fp/asm_noadx.go index 92f8cc0f4..090d05f05 100644 --- a/ecc/bn254/fp/asm_noadx.go +++ b/ecc/bn254/fp/asm_noadx.go @@ -23,6 +23,8 @@ package fp // certain errors (like fatal error: missing stackmap) // this ensures we test all asm path. var ( - supportAdx = false - _ = supportAdx + supportAdx = false + _ = supportAdx + supportAvx512 = false + _ = supportAvx512 ) diff --git a/ecc/bn254/fp/element_mul_amd64.s b/ecc/bn254/fp/element_mul_amd64.s index 9357a21d7..95e95e6fa 100644 --- a/ecc/bn254/fp/element_mul_amd64.s +++ b/ecc/bn254/fp/element_mul_amd64.s @@ -27,6 +27,9 @@ GLOBL q<>(SB), (RODATA+NOPTR), $32 // qInv0 q'[0] DATA qInv0<>(SB)/8, $0x87d20782e4866389 GLOBL qInv0<>(SB), (RODATA+NOPTR), $8 +// Mu +DATA mu<>(SB)/8, $0x000000054a474626 +GLOBL mu<>(SB), (RODATA+NOPTR), $8 #define REDUCE(ra0, ra1, ra2, ra3, rb0, rb1, rb2, rb3) \ MOVQ ra0, rb0; \ diff --git a/ecc/bn254/fp/element_ops_amd64.go b/ecc/bn254/fp/element_ops_amd64.go index b7e3b3542..3963ace9d 100644 --- a/ecc/bn254/fp/element_ops_amd64.go +++ b/ecc/bn254/fp/element_ops_amd64.go @@ -83,7 +83,12 @@ func scalarMulVec(res, a, b *Element, n uint64) // Sum computes the sum of all elements in the vector. 
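The accumulation scheme described in the assembly comments above can be modeled in a few lines of pure Go. The sketch below is illustrative only and is not part of the generated code; the names sumLanes, recombine and the small main are made up for the example. It splits each 64-bit word into its 32-bit halves and adds each half into its own 64-bit lane, which is what the VPMOVZXDQ/VPADDQ loop does with the Z registers, then folds the lanes back into a 320-bit integer the way the ADOXQ/ADCXQ chain does. The final single-word Barrett reduction mod q is omitted. It also makes the (1 << 32) + 1 cap in the Go wrapper concrete: that is the largest number of summands for which no lane can overflow, since (2^32+1)*(2^32-1) = 2^64-1.

// Pure-Go model of the lane-accumulation idea used by the generated sumVec
// assembly (illustrative sketch, not the generated code).
package main

import (
	"fmt"
	"math/big"
	"math/bits"
	"math/rand"
)

// sumLanes splits every 64-bit word of every element into its 32-bit halves
// and adds each half into its own 64-bit accumulator lane. With at most
// 2^32+1 summands a lane cannot overflow.
func sumLanes(elems [][4]uint64) (lo, hi [4]uint64) {
	for _, e := range elems {
		for i, w := range e {
			lo[i] += w & 0xffffffff
			hi[i] += w >> 32
		}
	}
	return
}

// recombine folds the eight lane totals into a 320-bit integer r0..r4,
// mirroring the ADOXQ/ADCXQ carry chain that precedes the Barrett reduction.
func recombine(lo, hi [4]uint64) (r [5]uint64) {
	var carry uint64
	for i := 0; i < 4; i++ {
		var c1, c2 uint64
		r[i], c1 = bits.Add64(lo[i], hi[i]<<32, 0)
		r[i], c2 = bits.Add64(r[i], carry, 0)
		carry = (hi[i] >> 32) + c1 + c2 // at most 2^32+1, never overflows
	}
	r[4] = carry
	return
}

func main() {
	const n = 1000
	elems := make([][4]uint64, n)
	want := new(big.Int)
	for i := range elems {
		for j := range elems[i] {
			elems[i][j] = rand.Uint64()
		}
		w := new(big.Int)
		for j := 3; j >= 0; j-- {
			w.Lsh(w, 64)
			w.Or(w, new(big.Int).SetUint64(elems[i][j]))
		}
		want.Add(want, w)
	}
	r := recombine(sumLanes(elems))
	got := new(big.Int)
	for j := 4; j >= 0; j-- {
		got.Lsh(got, 64)
		got.Or(got, new(big.Int).SetUint64(r[j]))
	}
	// the recombined 320-bit value equals the exact integer sum
	fmt.Println("match:", got.Cmp(want) == 0)
}
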
func (vector *Vector) Sum() (res Element) { - if len(*vector) == 0 { + n := uint64(len(*vector)) + const minN = 16 * 7 // AVX512 slower than generic for small n + const maxN = (1 << 32) + 1 + if !supportAvx512 || n <= minN || n >= maxN { + // call sumVecGeneric + sumVecGeneric(&res, *vector) return } sumVec(&res, &(*vector)[0], uint64(len(*vector))) diff --git a/ecc/bn254/fp/element_ops_amd64.s b/ecc/bn254/fp/element_ops_amd64.s index 6cb600960..87880fe50 100644 --- a/ecc/bn254/fp/element_ops_amd64.s +++ b/ecc/bn254/fp/element_ops_amd64.s @@ -27,6 +27,9 @@ GLOBL q<>(SB), (RODATA+NOPTR), $32 // qInv0 q'[0] DATA qInv0<>(SB)/8, $0x87d20782e4866389 GLOBL qInv0<>(SB), (RODATA+NOPTR), $8 +// Mu +DATA mu<>(SB)/8, $0x000000054a474626 +GLOBL mu<>(SB), (RODATA+NOPTR), $8 #define REDUCE(ra0, ra1, ra2, ra3, rb0, rb1, rb2, rb3) \ MOVQ ra0, rb0; \ @@ -628,76 +631,206 @@ noAdx_5: // sumVec(res, a *Element, n uint64) res = sum(a[0...n]) TEXT ·sumVec(SB), NOSPLIT, $0-24 - MOVQ $0x1555, CX - KMOVW CX, K1 - MOVQ $0xff80, CX - KMOVW CX, K2 - MOVQ $0x01ff, CX - KMOVW CX, K3 - MOVQ a+8(FP), AX - MOVQ n+16(FP), DX + + // Derived from https://github.com/a16z/vectorized-fields + // The idea is to use Z registers to accumulate the sum of elements, 4 by 4 + // first, we handle the case where n % 4 != 0 and add to the accumulators the 1, 2 or 3 remaining elements + // then, we loop over the elements 4 by 4 and accumulate the sum in the Z registers + // finally, we reduce the sum and store it in res + // + // when we move an element of a into a Z register, we use VPMOVZXDQ + // let's note w0...w3 the 4 64bits words of ai: w0 = ai[0], w1 = ai[1], w2 = ai[2], w3 = ai[3] + // VPMOVZXDQ(ai, Z0) will result in + // Z0= [hi(w3), lo(w3), hi(w2), lo(w2), hi(w1), lo(w1), hi(w0), lo(w0)] + // with hi(wi) the high 32 bits of wi and lo(wi) the low 32 bits of wi + // we can safely add 2^32+1 times Z registers constructed this way without overflow + // since each of this lo/hi bits are moved into a "64bits" slot + // N = 2^64-1 / 2^32-1 = 2^32+1 + // + // we then propagate the carry using ADOXQ and ADCXQ + // r0 = w0l + lo(woh) + // r1 = carry + hi(woh) + w1l + lo(w1h) + // r2 = carry + hi(w1h) + w2l + lo(w2h) + // r3 = carry + hi(w2h) + w3l + lo(w3h) + // r4 = carry + hi(w3h) + // we then reduce the sum using a single-word Barrett reduction + + MOVQ a+8(FP), R14 + MOVQ n+16(FP), R15 + + // initialize accumulators Z0, Z1, Z2, Z3 VXORPS Z0, Z0, Z0 VMOVDQA64 Z0, Z1 VMOVDQA64 Z0, Z2 VMOVDQA64 Z0, Z3 - TESTQ DX, DX - JEQ done_9 // n == 0, we are done - MOVQ DX, CX - ANDQ $3, CX - SHRQ $2, DX - CMPB CX, $1 - JEQ r1_10 // we have 1 remaining element - CMPB CX, $2 - JEQ r2_11 // we have 2 remaining elements - CMPB CX, $3 - JNE loop_8 // == 0; we have 0 remaining elements + + // n % 4 -> CX + // n / 4 -> R15 + MOVQ R15, CX + ANDQ $3, CX + SHRQ $2, R15 + CMPQ CX, $1 + JEQ r1_10 // we have 1 remaining element + CMPQ CX, $2 + JEQ r2_11 // we have 2 remaining elements + CMPQ CX, $3 + JNE loop4by4_8 // we have 0 remaining elements // we have 3 remaining elements - VPMOVZXDQ 2*32(AX), Z4 + VPMOVZXDQ 2*32(R14), Z4 VPADDQ Z4, Z0, Z0 r2_11: // we have 2 remaining elements - VPMOVZXDQ 1*32(AX), Z4 + VPMOVZXDQ 1*32(R14), Z4 VPADDQ Z4, Z1, Z1 r1_10: // we have 1 remaining element - VPMOVZXDQ 0*32(AX), Z4 + VPMOVZXDQ 0*32(R14), Z4 VPADDQ Z4, Z2, Z2 + MOVQ $32, DX + IMULQ CX, DX + ADDQ DX, R14 -loop_8: - TESTQ DX, DX +loop4by4_8: + TESTQ R15, R15 JEQ accumulate_12 // n == 0, we are going to accumulate - VPMOVZXDQ 0*32(AX), Z4 + VPMOVZXDQ 0*32(R14), Z4 
VPADDQ Z4, Z0, Z0 - VPMOVZXDQ 1*32(AX), Z4 + VPMOVZXDQ 1*32(R14), Z4 VPADDQ Z4, Z1, Z1 - VPMOVZXDQ 2*32(AX), Z4 + VPMOVZXDQ 2*32(R14), Z4 VPADDQ Z4, Z2, Z2 - VPMOVZXDQ 3*32(AX), Z4 + VPMOVZXDQ 3*32(R14), Z4 VPADDQ Z4, Z3, Z3 // increment pointers to visit next 4 elements - ADDQ $128, AX - DECQ DX // decrement n - JMP loop_8 + ADDQ $128, R14 + DECQ R15 // decrement n + JMP loop4by4_8 accumulate_12: - VPADDQ Z1, Z0, Z0 - VPADDQ Z3, Z2, Z2 - VPADDQ Z2, Z0, Z0 - MOVQ $8, CX - VALIGND $1, Z3, Z0, K2, Z3 - -propagate_13: - VPSRLQ $32, Z0, Z1 - VALIGND $2, Z0, Z0, K1, Z0 - VPADDQ Z1, Z0, Z0 - VALIGND $1, Z3, Z0, K2, Z3 - DECQ CX - JNE propagate_13 + // accumulate the 4 Z registers into Z0 + VPADDQ Z1, Z0, Z0 + VPADDQ Z3, Z2, Z2 + VPADDQ Z2, Z0, Z0 + + // carry propagation + // lo(w0) -> BX + // hi(w0) -> SI + // lo(w1) -> DI + // hi(w1) -> R8 + // lo(w2) -> R9 + // hi(w2) -> R10 + // lo(w3) -> R11 + // hi(w3) -> R12 + VMOVQ X0, BX + VALIGNQ $1, Z0, Z0, Z0 + VMOVQ X0, SI + VALIGNQ $1, Z0, Z0, Z0 + VMOVQ X0, DI + VALIGNQ $1, Z0, Z0, Z0 + VMOVQ X0, R8 + VALIGNQ $1, Z0, Z0, Z0 + VMOVQ X0, R9 + VALIGNQ $1, Z0, Z0, Z0 + VMOVQ X0, R10 + VALIGNQ $1, Z0, Z0, Z0 + VMOVQ X0, R11 + VALIGNQ $1, Z0, Z0, Z0 + VMOVQ X0, R12 + + // lo(hi(wo)) -> R13 + // lo(hi(w1)) -> CX + // lo(hi(w2)) -> R15 + // lo(hi(w3)) -> R14 + XORQ AX, AX // clear the flags + MOVQ SI, R13 + ANDQ $0xffffffff, R13 + SHLQ $32, R13 + SHRQ $32, SI + MOVQ R8, CX + ANDQ $0xffffffff, CX + SHLQ $32, CX + SHRQ $32, R8 + MOVQ R10, R15 + ANDQ $0xffffffff, R15 + SHLQ $32, R15 + SHRQ $32, R10 + MOVQ R12, R14 + ANDQ $0xffffffff, R14 + SHLQ $32, R14 + SHRQ $32, R12 + + // r0 = w0l + lo(woh) + // r1 = carry + hi(woh) + w1l + lo(w1h) + // r2 = carry + hi(w1h) + w2l + lo(w2h) + // r3 = carry + hi(w2h) + w3l + lo(w3h) + // r4 = carry + hi(w3h) + + XORQ AX, AX // clear the flags + ADOXQ R13, BX + ADOXQ CX, DI + ADCXQ SI, DI + ADOXQ R15, R9 + ADCXQ R8, R9 + ADOXQ R14, R11 + ADCXQ R10, R11 + ADOXQ AX, R12 + ADCXQ AX, R12 + + // r[0] -> BX + // r[1] -> DI + // r[2] -> R9 + // r[3] -> R11 + // r[4] -> R12 + // reduce using single-word Barrett + // mu=2^288 / q -> SI + MOVQ mu<>(SB), SI + MOVQ R11, AX + SHRQ $32, R12, AX + MULQ SI + MULXQ q<>+0(SB), AX, SI + SUBQ AX, BX + SBBQ SI, DI + MULXQ q<>+16(SB), AX, SI + SBBQ AX, R9 + SBBQ SI, R11 + SBBQ $0, R12 + MULXQ q<>+8(SB), AX, SI + SUBQ AX, DI + SBBQ SI, R9 + MULXQ q<>+24(SB), AX, SI + SBBQ AX, R11 + SBBQ SI, R12 + MOVQ res+0(FP), SI + MOVQ BX, 0(SI) + MOVQ DI, 8(SI) + MOVQ R9, 16(SI) + MOVQ R11, 24(SI) + + // TODO @gbotrel check if 2 conditional substracts is guaranteed to be suffisant for mod reduce + SUBQ q<>+0(SB), BX + SBBQ q<>+8(SB), DI + SBBQ q<>+16(SB), R9 + SBBQ q<>+24(SB), R11 + SBBQ $0, R12 + JCS done_9 + MOVQ BX, 0(SI) + MOVQ DI, 8(SI) + MOVQ R9, 16(SI) + MOVQ R11, 24(SI) + SUBQ q<>+0(SB), BX + SBBQ q<>+8(SB), DI + SBBQ q<>+16(SB), R9 + SBBQ q<>+24(SB), R11 + SBBQ $0, R12 + JCS done_9 + MOVQ BX, 0(SI) + MOVQ DI, 8(SI) + MOVQ R9, 16(SI) + MOVQ R11, 24(SI) done_9: - MOVQ res+0(FP), AX RET diff --git a/ecc/bn254/fp/element_test.go b/ecc/bn254/fp/element_test.go index 2ae7618d2..14847a533 100644 --- a/ecc/bn254/fp/element_test.go +++ b/ecc/bn254/fp/element_test.go @@ -710,7 +710,7 @@ func TestElementLexicographicallyLargest(t *testing.T) { func TestElementVecOps(t *testing.T) { assert := require.New(t) - const N = 7 + const N = 1024*16 + 4 a := make(Vector, N) b := make(Vector, N) c := make(Vector, N) @@ -749,13 +749,14 @@ func TestElementVecOps(t *testing.T) { for i := 0; i < N; i++ { sum.Add(&sum, &c[i]) 
} + assert.True(sum.Equal(&computed), "Vector sum failed") } func BenchmarkElementVecOps(b *testing.B) { // note; to benchmark against "no asm" version, use the following // build tag: -tags purego - const N = 1024 + const N = 1 << 20 a1 := make(Vector, N) b1 := make(Vector, N) c1 := make(Vector, N) diff --git a/ecc/bn254/fr/asm.go b/ecc/bn254/fr/asm.go index da061913b..c756b153a 100644 --- a/ecc/bn254/fr/asm.go +++ b/ecc/bn254/fr/asm.go @@ -22,6 +22,8 @@ package fr import "golang.org/x/sys/cpu" var ( - supportAdx = cpu.X86.HasADX && cpu.X86.HasBMI2 - _ = supportAdx + supportAdx = cpu.X86.HasADX && cpu.X86.HasBMI2 + _ = supportAdx + supportAvx512 = cpu.X86.HasAVX512 && cpu.X86.HasAVX512DQ + _ = supportAvx512 ) diff --git a/ecc/bn254/fr/asm_noadx.go b/ecc/bn254/fr/asm_noadx.go index 7f52ffa19..a116cd66e 100644 --- a/ecc/bn254/fr/asm_noadx.go +++ b/ecc/bn254/fr/asm_noadx.go @@ -23,6 +23,8 @@ package fr // certain errors (like fatal error: missing stackmap) // this ensures we test all asm path. var ( - supportAdx = false - _ = supportAdx + supportAdx = false + _ = supportAdx + supportAvx512 = false + _ = supportAvx512 ) diff --git a/ecc/bn254/fr/element_mul_amd64.s b/ecc/bn254/fr/element_mul_amd64.s index 4a9321837..98e98ef6b 100644 --- a/ecc/bn254/fr/element_mul_amd64.s +++ b/ecc/bn254/fr/element_mul_amd64.s @@ -27,6 +27,9 @@ GLOBL q<>(SB), (RODATA+NOPTR), $32 // qInv0 q'[0] DATA qInv0<>(SB)/8, $0xc2e1f593efffffff GLOBL qInv0<>(SB), (RODATA+NOPTR), $8 +// Mu +DATA mu<>(SB)/8, $0x000000054a474626 +GLOBL mu<>(SB), (RODATA+NOPTR), $8 #define REDUCE(ra0, ra1, ra2, ra3, rb0, rb1, rb2, rb3) \ MOVQ ra0, rb0; \ diff --git a/ecc/bn254/fr/element_ops_amd64.go b/ecc/bn254/fr/element_ops_amd64.go index 88c2d0c58..fe4606378 100644 --- a/ecc/bn254/fr/element_ops_amd64.go +++ b/ecc/bn254/fr/element_ops_amd64.go @@ -83,7 +83,12 @@ func scalarMulVec(res, a, b *Element, n uint64) // Sum computes the sum of all elements in the vector. 
func (vector *Vector) Sum() (res Element) { - if len(*vector) == 0 { + n := uint64(len(*vector)) + const minN = 16 * 7 // AVX512 slower than generic for small n + const maxN = (1 << 32) + 1 + if !supportAvx512 || n <= minN || n >= maxN { + // call sumVecGeneric + sumVecGeneric(&res, *vector) return } sumVec(&res, &(*vector)[0], uint64(len(*vector))) diff --git a/ecc/bn254/fr/element_ops_amd64.s b/ecc/bn254/fr/element_ops_amd64.s index ede44f461..49581b69e 100644 --- a/ecc/bn254/fr/element_ops_amd64.s +++ b/ecc/bn254/fr/element_ops_amd64.s @@ -27,6 +27,9 @@ GLOBL q<>(SB), (RODATA+NOPTR), $32 // qInv0 q'[0] DATA qInv0<>(SB)/8, $0xc2e1f593efffffff GLOBL qInv0<>(SB), (RODATA+NOPTR), $8 +// Mu +DATA mu<>(SB)/8, $0x000000054a474626 +GLOBL mu<>(SB), (RODATA+NOPTR), $8 #define REDUCE(ra0, ra1, ra2, ra3, rb0, rb1, rb2, rb3) \ MOVQ ra0, rb0; \ @@ -628,76 +631,206 @@ noAdx_5: // sumVec(res, a *Element, n uint64) res = sum(a[0...n]) TEXT ·sumVec(SB), NOSPLIT, $0-24 - MOVQ $0x1555, CX - KMOVW CX, K1 - MOVQ $0xff80, CX - KMOVW CX, K2 - MOVQ $0x01ff, CX - KMOVW CX, K3 - MOVQ a+8(FP), AX - MOVQ n+16(FP), DX + + // Derived from https://github.com/a16z/vectorized-fields + // The idea is to use Z registers to accumulate the sum of elements, 4 by 4 + // first, we handle the case where n % 4 != 0 and add to the accumulators the 1, 2 or 3 remaining elements + // then, we loop over the elements 4 by 4 and accumulate the sum in the Z registers + // finally, we reduce the sum and store it in res + // + // when we move an element of a into a Z register, we use VPMOVZXDQ + // let's note w0...w3 the 4 64bits words of ai: w0 = ai[0], w1 = ai[1], w2 = ai[2], w3 = ai[3] + // VPMOVZXDQ(ai, Z0) will result in + // Z0= [hi(w3), lo(w3), hi(w2), lo(w2), hi(w1), lo(w1), hi(w0), lo(w0)] + // with hi(wi) the high 32 bits of wi and lo(wi) the low 32 bits of wi + // we can safely add 2^32+1 times Z registers constructed this way without overflow + // since each of this lo/hi bits are moved into a "64bits" slot + // N = 2^64-1 / 2^32-1 = 2^32+1 + // + // we then propagate the carry using ADOXQ and ADCXQ + // r0 = w0l + lo(woh) + // r1 = carry + hi(woh) + w1l + lo(w1h) + // r2 = carry + hi(w1h) + w2l + lo(w2h) + // r3 = carry + hi(w2h) + w3l + lo(w3h) + // r4 = carry + hi(w3h) + // we then reduce the sum using a single-word Barrett reduction + + MOVQ a+8(FP), R14 + MOVQ n+16(FP), R15 + + // initialize accumulators Z0, Z1, Z2, Z3 VXORPS Z0, Z0, Z0 VMOVDQA64 Z0, Z1 VMOVDQA64 Z0, Z2 VMOVDQA64 Z0, Z3 - TESTQ DX, DX - JEQ done_9 // n == 0, we are done - MOVQ DX, CX - ANDQ $3, CX - SHRQ $2, DX - CMPB CX, $1 - JEQ r1_10 // we have 1 remaining element - CMPB CX, $2 - JEQ r2_11 // we have 2 remaining elements - CMPB CX, $3 - JNE loop_8 // == 0; we have 0 remaining elements + + // n % 4 -> CX + // n / 4 -> R15 + MOVQ R15, CX + ANDQ $3, CX + SHRQ $2, R15 + CMPQ CX, $1 + JEQ r1_10 // we have 1 remaining element + CMPQ CX, $2 + JEQ r2_11 // we have 2 remaining elements + CMPQ CX, $3 + JNE loop4by4_8 // we have 0 remaining elements // we have 3 remaining elements - VPMOVZXDQ 2*32(AX), Z4 + VPMOVZXDQ 2*32(R14), Z4 VPADDQ Z4, Z0, Z0 r2_11: // we have 2 remaining elements - VPMOVZXDQ 1*32(AX), Z4 + VPMOVZXDQ 1*32(R14), Z4 VPADDQ Z4, Z1, Z1 r1_10: // we have 1 remaining element - VPMOVZXDQ 0*32(AX), Z4 + VPMOVZXDQ 0*32(R14), Z4 VPADDQ Z4, Z2, Z2 + MOVQ $32, DX + IMULQ CX, DX + ADDQ DX, R14 -loop_8: - TESTQ DX, DX +loop4by4_8: + TESTQ R15, R15 JEQ accumulate_12 // n == 0, we are going to accumulate - VPMOVZXDQ 0*32(AX), Z4 + VPMOVZXDQ 0*32(R14), Z4 
VPADDQ Z4, Z0, Z0 - VPMOVZXDQ 1*32(AX), Z4 + VPMOVZXDQ 1*32(R14), Z4 VPADDQ Z4, Z1, Z1 - VPMOVZXDQ 2*32(AX), Z4 + VPMOVZXDQ 2*32(R14), Z4 VPADDQ Z4, Z2, Z2 - VPMOVZXDQ 3*32(AX), Z4 + VPMOVZXDQ 3*32(R14), Z4 VPADDQ Z4, Z3, Z3 // increment pointers to visit next 4 elements - ADDQ $128, AX - DECQ DX // decrement n - JMP loop_8 + ADDQ $128, R14 + DECQ R15 // decrement n + JMP loop4by4_8 accumulate_12: - VPADDQ Z1, Z0, Z0 - VPADDQ Z3, Z2, Z2 - VPADDQ Z2, Z0, Z0 - MOVQ $8, CX - VALIGND $1, Z3, Z0, K2, Z3 - -propagate_13: - VPSRLQ $32, Z0, Z1 - VALIGND $2, Z0, Z0, K1, Z0 - VPADDQ Z1, Z0, Z0 - VALIGND $1, Z3, Z0, K2, Z3 - DECQ CX - JNE propagate_13 + // accumulate the 4 Z registers into Z0 + VPADDQ Z1, Z0, Z0 + VPADDQ Z3, Z2, Z2 + VPADDQ Z2, Z0, Z0 + + // carry propagation + // lo(w0) -> BX + // hi(w0) -> SI + // lo(w1) -> DI + // hi(w1) -> R8 + // lo(w2) -> R9 + // hi(w2) -> R10 + // lo(w3) -> R11 + // hi(w3) -> R12 + VMOVQ X0, BX + VALIGNQ $1, Z0, Z0, Z0 + VMOVQ X0, SI + VALIGNQ $1, Z0, Z0, Z0 + VMOVQ X0, DI + VALIGNQ $1, Z0, Z0, Z0 + VMOVQ X0, R8 + VALIGNQ $1, Z0, Z0, Z0 + VMOVQ X0, R9 + VALIGNQ $1, Z0, Z0, Z0 + VMOVQ X0, R10 + VALIGNQ $1, Z0, Z0, Z0 + VMOVQ X0, R11 + VALIGNQ $1, Z0, Z0, Z0 + VMOVQ X0, R12 + + // lo(hi(wo)) -> R13 + // lo(hi(w1)) -> CX + // lo(hi(w2)) -> R15 + // lo(hi(w3)) -> R14 + XORQ AX, AX // clear the flags + MOVQ SI, R13 + ANDQ $0xffffffff, R13 + SHLQ $32, R13 + SHRQ $32, SI + MOVQ R8, CX + ANDQ $0xffffffff, CX + SHLQ $32, CX + SHRQ $32, R8 + MOVQ R10, R15 + ANDQ $0xffffffff, R15 + SHLQ $32, R15 + SHRQ $32, R10 + MOVQ R12, R14 + ANDQ $0xffffffff, R14 + SHLQ $32, R14 + SHRQ $32, R12 + + // r0 = w0l + lo(woh) + // r1 = carry + hi(woh) + w1l + lo(w1h) + // r2 = carry + hi(w1h) + w2l + lo(w2h) + // r3 = carry + hi(w2h) + w3l + lo(w3h) + // r4 = carry + hi(w3h) + + XORQ AX, AX // clear the flags + ADOXQ R13, BX + ADOXQ CX, DI + ADCXQ SI, DI + ADOXQ R15, R9 + ADCXQ R8, R9 + ADOXQ R14, R11 + ADCXQ R10, R11 + ADOXQ AX, R12 + ADCXQ AX, R12 + + // r[0] -> BX + // r[1] -> DI + // r[2] -> R9 + // r[3] -> R11 + // r[4] -> R12 + // reduce using single-word Barrett + // mu=2^288 / q -> SI + MOVQ mu<>(SB), SI + MOVQ R11, AX + SHRQ $32, R12, AX + MULQ SI + MULXQ q<>+0(SB), AX, SI + SUBQ AX, BX + SBBQ SI, DI + MULXQ q<>+16(SB), AX, SI + SBBQ AX, R9 + SBBQ SI, R11 + SBBQ $0, R12 + MULXQ q<>+8(SB), AX, SI + SUBQ AX, DI + SBBQ SI, R9 + MULXQ q<>+24(SB), AX, SI + SBBQ AX, R11 + SBBQ SI, R12 + MOVQ res+0(FP), SI + MOVQ BX, 0(SI) + MOVQ DI, 8(SI) + MOVQ R9, 16(SI) + MOVQ R11, 24(SI) + + // TODO @gbotrel check if 2 conditional substracts is guaranteed to be suffisant for mod reduce + SUBQ q<>+0(SB), BX + SBBQ q<>+8(SB), DI + SBBQ q<>+16(SB), R9 + SBBQ q<>+24(SB), R11 + SBBQ $0, R12 + JCS done_9 + MOVQ BX, 0(SI) + MOVQ DI, 8(SI) + MOVQ R9, 16(SI) + MOVQ R11, 24(SI) + SUBQ q<>+0(SB), BX + SBBQ q<>+8(SB), DI + SBBQ q<>+16(SB), R9 + SBBQ q<>+24(SB), R11 + SBBQ $0, R12 + JCS done_9 + MOVQ BX, 0(SI) + MOVQ DI, 8(SI) + MOVQ R9, 16(SI) + MOVQ R11, 24(SI) done_9: - MOVQ res+0(FP), AX RET diff --git a/ecc/bn254/fr/element_test.go b/ecc/bn254/fr/element_test.go index 059ab534f..8c8e431dc 100644 --- a/ecc/bn254/fr/element_test.go +++ b/ecc/bn254/fr/element_test.go @@ -710,7 +710,7 @@ func TestElementLexicographicallyLargest(t *testing.T) { func TestElementVecOps(t *testing.T) { assert := require.New(t) - const N = 7 + const N = 1024*16 + 4 a := make(Vector, N) b := make(Vector, N) c := make(Vector, N) @@ -749,13 +749,14 @@ func TestElementVecOps(t *testing.T) { for i := 0; i < N; i++ { sum.Add(&sum, &c[i]) 
} + assert.True(sum.Equal(&computed), "Vector sum failed") } func BenchmarkElementVecOps(b *testing.B) { // note; to benchmark against "no asm" version, use the following // build tag: -tags purego - const N = 1024 + const N = 1 << 20 a1 := make(Vector, N) b1 := make(Vector, N) c1 := make(Vector, N) diff --git a/ecc/bn254/internal/fptower/e2_amd64.s b/ecc/bn254/internal/fptower/e2_amd64.s index 43ffb7f16..4e69f4c3f 100644 --- a/ecc/bn254/internal/fptower/e2_amd64.s +++ b/ecc/bn254/internal/fptower/e2_amd64.s @@ -25,6 +25,9 @@ GLOBL q<>(SB), (RODATA+NOPTR), $32 // qInv0 q'[0] DATA qInv0<>(SB)/8, $0x87d20782e4866389 GLOBL qInv0<>(SB), (RODATA+NOPTR), $8 +// Mu +DATA mu<>(SB)/8, $0x000000054a474626 +GLOBL mu<>(SB), (RODATA+NOPTR), $8 #define REDUCE(ra0, ra1, ra2, ra3, rb0, rb1, rb2, rb3) \ MOVQ ra0, rb0; \ diff --git a/ecc/bw6-633/fp/element_test.go b/ecc/bw6-633/fp/element_test.go index 103721a07..73b3146b6 100644 --- a/ecc/bw6-633/fp/element_test.go +++ b/ecc/bw6-633/fp/element_test.go @@ -722,7 +722,7 @@ func TestElementLexicographicallyLargest(t *testing.T) { func TestElementVecOps(t *testing.T) { assert := require.New(t) - const N = 7 + const N = 1024*16 + 4 a := make(Vector, N) b := make(Vector, N) c := make(Vector, N) @@ -761,13 +761,14 @@ func TestElementVecOps(t *testing.T) { for i := 0; i < N; i++ { sum.Add(&sum, &c[i]) } + assert.True(sum.Equal(&computed), "Vector sum failed") } func BenchmarkElementVecOps(b *testing.B) { // note; to benchmark against "no asm" version, use the following // build tag: -tags purego - const N = 1024 + const N = 1 << 20 a1 := make(Vector, N) b1 := make(Vector, N) c1 := make(Vector, N) diff --git a/ecc/bw6-633/fr/element_test.go b/ecc/bw6-633/fr/element_test.go index b12fba199..a4b82539b 100644 --- a/ecc/bw6-633/fr/element_test.go +++ b/ecc/bw6-633/fr/element_test.go @@ -712,7 +712,7 @@ func TestElementLexicographicallyLargest(t *testing.T) { func TestElementVecOps(t *testing.T) { assert := require.New(t) - const N = 7 + const N = 1024*16 + 4 a := make(Vector, N) b := make(Vector, N) c := make(Vector, N) @@ -751,13 +751,14 @@ func TestElementVecOps(t *testing.T) { for i := 0; i < N; i++ { sum.Add(&sum, &c[i]) } + assert.True(sum.Equal(&computed), "Vector sum failed") } func BenchmarkElementVecOps(b *testing.B) { // note; to benchmark against "no asm" version, use the following // build tag: -tags purego - const N = 1024 + const N = 1 << 20 a1 := make(Vector, N) b1 := make(Vector, N) c1 := make(Vector, N) diff --git a/ecc/bw6-761/fp/element_test.go b/ecc/bw6-761/fp/element_test.go index 083a0cda2..d1f404566 100644 --- a/ecc/bw6-761/fp/element_test.go +++ b/ecc/bw6-761/fp/element_test.go @@ -726,7 +726,7 @@ func TestElementLexicographicallyLargest(t *testing.T) { func TestElementVecOps(t *testing.T) { assert := require.New(t) - const N = 7 + const N = 1024*16 + 4 a := make(Vector, N) b := make(Vector, N) c := make(Vector, N) @@ -765,13 +765,14 @@ func TestElementVecOps(t *testing.T) { for i := 0; i < N; i++ { sum.Add(&sum, &c[i]) } + assert.True(sum.Equal(&computed), "Vector sum failed") } func BenchmarkElementVecOps(b *testing.B) { // note; to benchmark against "no asm" version, use the following // build tag: -tags purego - const N = 1024 + const N = 1 << 20 a1 := make(Vector, N) b1 := make(Vector, N) c1 := make(Vector, N) diff --git a/ecc/bw6-761/fr/element_test.go b/ecc/bw6-761/fr/element_test.go index 22c72ba04..90d162aee 100644 --- a/ecc/bw6-761/fr/element_test.go +++ b/ecc/bw6-761/fr/element_test.go @@ -714,7 +714,7 @@ func 
TestElementLexicographicallyLargest(t *testing.T) { func TestElementVecOps(t *testing.T) { assert := require.New(t) - const N = 7 + const N = 1024*16 + 4 a := make(Vector, N) b := make(Vector, N) c := make(Vector, N) @@ -753,13 +753,14 @@ func TestElementVecOps(t *testing.T) { for i := 0; i < N; i++ { sum.Add(&sum, &c[i]) } + assert.True(sum.Equal(&computed), "Vector sum failed") } func BenchmarkElementVecOps(b *testing.B) { // note; to benchmark against "no asm" version, use the following // build tag: -tags purego - const N = 1024 + const N = 1 << 20 a1 := make(Vector, N) b1 := make(Vector, N) c1 := make(Vector, N) diff --git a/ecc/secp256k1/fp/element_test.go b/ecc/secp256k1/fp/element_test.go index 0781f7dd3..4cedc9a12 100644 --- a/ecc/secp256k1/fp/element_test.go +++ b/ecc/secp256k1/fp/element_test.go @@ -708,7 +708,7 @@ func TestElementLexicographicallyLargest(t *testing.T) { func TestElementVecOps(t *testing.T) { assert := require.New(t) - const N = 7 + const N = 1024*16 + 4 a := make(Vector, N) b := make(Vector, N) c := make(Vector, N) @@ -747,13 +747,14 @@ func TestElementVecOps(t *testing.T) { for i := 0; i < N; i++ { sum.Add(&sum, &c[i]) } + assert.True(sum.Equal(&computed), "Vector sum failed") } func BenchmarkElementVecOps(b *testing.B) { // note; to benchmark against "no asm" version, use the following // build tag: -tags purego - const N = 1024 + const N = 1 << 20 a1 := make(Vector, N) b1 := make(Vector, N) c1 := make(Vector, N) diff --git a/ecc/secp256k1/fr/element_test.go b/ecc/secp256k1/fr/element_test.go index 85b23645b..f97f76998 100644 --- a/ecc/secp256k1/fr/element_test.go +++ b/ecc/secp256k1/fr/element_test.go @@ -708,7 +708,7 @@ func TestElementLexicographicallyLargest(t *testing.T) { func TestElementVecOps(t *testing.T) { assert := require.New(t) - const N = 7 + const N = 1024*16 + 4 a := make(Vector, N) b := make(Vector, N) c := make(Vector, N) @@ -747,13 +747,14 @@ func TestElementVecOps(t *testing.T) { for i := 0; i < N; i++ { sum.Add(&sum, &c[i]) } + assert.True(sum.Equal(&computed), "Vector sum failed") } func BenchmarkElementVecOps(b *testing.B) { // note; to benchmark against "no asm" version, use the following // build tag: -tags purego - const N = 1024 + const N = 1 << 20 a1 := make(Vector, N) b1 := make(Vector, N) c1 := make(Vector, N) diff --git a/ecc/stark-curve/fp/asm.go b/ecc/stark-curve/fp/asm.go index 0481989ec..bb604d40b 100644 --- a/ecc/stark-curve/fp/asm.go +++ b/ecc/stark-curve/fp/asm.go @@ -22,6 +22,8 @@ package fp import "golang.org/x/sys/cpu" var ( - supportAdx = cpu.X86.HasADX && cpu.X86.HasBMI2 - _ = supportAdx + supportAdx = cpu.X86.HasADX && cpu.X86.HasBMI2 + _ = supportAdx + supportAvx512 = cpu.X86.HasAVX512 && cpu.X86.HasAVX512DQ + _ = supportAvx512 ) diff --git a/ecc/stark-curve/fp/asm_noadx.go b/ecc/stark-curve/fp/asm_noadx.go index 92f8cc0f4..090d05f05 100644 --- a/ecc/stark-curve/fp/asm_noadx.go +++ b/ecc/stark-curve/fp/asm_noadx.go @@ -23,6 +23,8 @@ package fp // certain errors (like fatal error: missing stackmap) // this ensures we test all asm path. 
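With supportAdx and supportAvx512 both forced to false in these no-asm builds, every Sum call takes the generic scalar path. That fallback, sumVecGeneric, is referenced by the wrappers in this patch but its body is not part of the diff; a minimal sketch of what it is assumed to do with the element API is:

// Assumed shape of the generic fallback used when AVX512 is unavailable,
// or when the vector is too small or too large for the vectorized kernel.
// Not part of this diff; shown for reference only.
func sumVecGeneric(res *Element, a Vector) {
	for i := 0; i < len(a); i++ {
		res.Add(res, &a[i]) // modular addition keeps the accumulator reduced
	}
}

The purego build tag mentioned in the benchmark comments should exercise the equivalent generic code, e.g. something like go test -tags purego -bench ElementVecOps.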
var ( - supportAdx = false - _ = supportAdx + supportAdx = false + _ = supportAdx + supportAvx512 = false + _ = supportAvx512 ) diff --git a/ecc/stark-curve/fp/element_mul_amd64.s b/ecc/stark-curve/fp/element_mul_amd64.s index fab328c86..36bbb8a76 100644 --- a/ecc/stark-curve/fp/element_mul_amd64.s +++ b/ecc/stark-curve/fp/element_mul_amd64.s @@ -27,6 +27,9 @@ GLOBL q<>(SB), (RODATA+NOPTR), $32 // qInv0 q'[0] DATA qInv0<>(SB)/8, $0xffffffffffffffff GLOBL qInv0<>(SB), (RODATA+NOPTR), $8 +// Mu +DATA mu<>(SB)/8, $0x0000001fffffffff +GLOBL mu<>(SB), (RODATA+NOPTR), $8 #define REDUCE(ra0, ra1, ra2, ra3, rb0, rb1, rb2, rb3) \ MOVQ ra0, rb0; \ diff --git a/ecc/stark-curve/fp/element_ops_amd64.go b/ecc/stark-curve/fp/element_ops_amd64.go index b7e3b3542..3963ace9d 100644 --- a/ecc/stark-curve/fp/element_ops_amd64.go +++ b/ecc/stark-curve/fp/element_ops_amd64.go @@ -83,7 +83,12 @@ func scalarMulVec(res, a, b *Element, n uint64) // Sum computes the sum of all elements in the vector. func (vector *Vector) Sum() (res Element) { - if len(*vector) == 0 { + n := uint64(len(*vector)) + const minN = 16 * 7 // AVX512 slower than generic for small n + const maxN = (1 << 32) + 1 + if !supportAvx512 || n <= minN || n >= maxN { + // call sumVecGeneric + sumVecGeneric(&res, *vector) return } sumVec(&res, &(*vector)[0], uint64(len(*vector))) diff --git a/ecc/stark-curve/fp/element_ops_amd64.s b/ecc/stark-curve/fp/element_ops_amd64.s index bd1ce4e36..d65b7f196 100644 --- a/ecc/stark-curve/fp/element_ops_amd64.s +++ b/ecc/stark-curve/fp/element_ops_amd64.s @@ -27,6 +27,9 @@ GLOBL q<>(SB), (RODATA+NOPTR), $32 // qInv0 q'[0] DATA qInv0<>(SB)/8, $0xffffffffffffffff GLOBL qInv0<>(SB), (RODATA+NOPTR), $8 +// Mu +DATA mu<>(SB)/8, $0x0000001fffffffff +GLOBL mu<>(SB), (RODATA+NOPTR), $8 #define REDUCE(ra0, ra1, ra2, ra3, rb0, rb1, rb2, rb3) \ MOVQ ra0, rb0; \ @@ -628,46 +631,206 @@ noAdx_5: // sumVec(res, a *Element, n uint64) res = sum(a[0...n]) TEXT ·sumVec(SB), NOSPLIT, $0-24 - MOVQ a+8(FP), AX - MOVQ n+16(FP), DX - VXORQ ZMM0, ZMM0, ZMM0 - VMOVDQA64 ZMM0, ZMM1 - VMOVDQA64 ZMM0, ZMM2 - VMOVDQA64 ZMM0, ZMM3 - TESTQ DX, DX - JEQ done_9 // n == 0, we are done - MOVQ DX, CX - ANDQ $3, CX - SHRQ $2, DX - CMPQ $1, CX - JEQ r1_10 // we have 1 remaining element - CMPQ $2, CX - JEQ r2_11 // we have 2 remaining elements - CMPQ $3, CX - JNE loop_8 // == 0; we have 0 remaining elements + + // Derived from https://github.com/a16z/vectorized-fields + // The idea is to use Z registers to accumulate the sum of elements, 4 by 4 + // first, we handle the case where n % 4 != 0 and add to the accumulators the 1, 2 or 3 remaining elements + // then, we loop over the elements 4 by 4 and accumulate the sum in the Z registers + // finally, we reduce the sum and store it in res + // + // when we move an element of a into a Z register, we use VPMOVZXDQ + // let's note w0...w3 the 4 64bits words of ai: w0 = ai[0], w1 = ai[1], w2 = ai[2], w3 = ai[3] + // VPMOVZXDQ(ai, Z0) will result in + // Z0= [hi(w3), lo(w3), hi(w2), lo(w2), hi(w1), lo(w1), hi(w0), lo(w0)] + // with hi(wi) the high 32 bits of wi and lo(wi) the low 32 bits of wi + // we can safely add 2^32+1 times Z registers constructed this way without overflow + // since each of this lo/hi bits are moved into a "64bits" slot + // N = 2^64-1 / 2^32-1 = 2^32+1 + // + // we then propagate the carry using ADOXQ and ADCXQ + // r0 = w0l + lo(woh) + // r1 = carry + hi(woh) + w1l + lo(w1h) + // r2 = carry + hi(w1h) + w2l + lo(w2h) + // r3 = carry + hi(w2h) + w3l + lo(w3h) + // r4 = carry + hi(w3h) + 
// we then reduce the sum using a single-word Barrett reduction + + MOVQ a+8(FP), R14 + MOVQ n+16(FP), R15 + + // initialize accumulators Z0, Z1, Z2, Z3 + VXORPS Z0, Z0, Z0 + VMOVDQA64 Z0, Z1 + VMOVDQA64 Z0, Z2 + VMOVDQA64 Z0, Z3 + + // n % 4 -> CX + // n / 4 -> R15 + MOVQ R15, CX + ANDQ $3, CX + SHRQ $2, R15 + CMPQ CX, $1 + JEQ r1_10 // we have 1 remaining element + CMPQ CX, $2 + JEQ r2_11 // we have 2 remaining elements + CMPQ CX, $3 + JNE loop4by4_8 // we have 0 remaining elements // we have 3 remaining elements - VPMOVZXDQ 2*32(AX), ZMM4 - VPADDQ ZMM4, ZMM0, ZMM0 + VPMOVZXDQ 2*32(R14), Z4 + VPADDQ Z4, Z0, Z0 r2_11: // we have 2 remaining elements - VPMOVZXDQ 1*32(AX), ZMM4 - VPADDQ ZMM4, ZMM1, ZMM1 + VPMOVZXDQ 1*32(R14), Z4 + VPADDQ Z4, Z1, Z1 r1_10: // we have 1 remaining element - VPMOVZXDQ 0*32(AX), ZMM4 - VPADDQ ZMM4, ZMM2, ZMM2 - TESTQ DX, DX - JEQ done_9 // n == 0, we are done - -loop_8: - // increment pointers to visit next element - ADDQ $128, AX - DECQ DX // decrement n - JMP loop_8 + VPMOVZXDQ 0*32(R14), Z4 + VPADDQ Z4, Z2, Z2 + MOVQ $32, DX + IMULQ CX, DX + ADDQ DX, R14 + +loop4by4_8: + TESTQ R15, R15 + JEQ accumulate_12 // n == 0, we are going to accumulate + VPMOVZXDQ 0*32(R14), Z4 + VPADDQ Z4, Z0, Z0 + VPMOVZXDQ 1*32(R14), Z4 + VPADDQ Z4, Z1, Z1 + VPMOVZXDQ 2*32(R14), Z4 + VPADDQ Z4, Z2, Z2 + VPMOVZXDQ 3*32(R14), Z4 + VPADDQ Z4, Z3, Z3 + + // increment pointers to visit next 4 elements + ADDQ $128, R14 + DECQ R15 // decrement n + JMP loop4by4_8 + +accumulate_12: + // accumulate the 4 Z registers into Z0 + VPADDQ Z1, Z0, Z0 + VPADDQ Z3, Z2, Z2 + VPADDQ Z2, Z0, Z0 + + // carry propagation + // lo(w0) -> BX + // hi(w0) -> SI + // lo(w1) -> DI + // hi(w1) -> R8 + // lo(w2) -> R9 + // hi(w2) -> R10 + // lo(w3) -> R11 + // hi(w3) -> R12 + VMOVQ X0, BX + VALIGNQ $1, Z0, Z0, Z0 + VMOVQ X0, SI + VALIGNQ $1, Z0, Z0, Z0 + VMOVQ X0, DI + VALIGNQ $1, Z0, Z0, Z0 + VMOVQ X0, R8 + VALIGNQ $1, Z0, Z0, Z0 + VMOVQ X0, R9 + VALIGNQ $1, Z0, Z0, Z0 + VMOVQ X0, R10 + VALIGNQ $1, Z0, Z0, Z0 + VMOVQ X0, R11 + VALIGNQ $1, Z0, Z0, Z0 + VMOVQ X0, R12 + + // lo(hi(wo)) -> R13 + // lo(hi(w1)) -> CX + // lo(hi(w2)) -> R15 + // lo(hi(w3)) -> R14 + XORQ AX, AX // clear the flags + MOVQ SI, R13 + ANDQ $0xffffffff, R13 + SHLQ $32, R13 + SHRQ $32, SI + MOVQ R8, CX + ANDQ $0xffffffff, CX + SHLQ $32, CX + SHRQ $32, R8 + MOVQ R10, R15 + ANDQ $0xffffffff, R15 + SHLQ $32, R15 + SHRQ $32, R10 + MOVQ R12, R14 + ANDQ $0xffffffff, R14 + SHLQ $32, R14 + SHRQ $32, R12 + + // r0 = w0l + lo(woh) + // r1 = carry + hi(woh) + w1l + lo(w1h) + // r2 = carry + hi(w1h) + w2l + lo(w2h) + // r3 = carry + hi(w2h) + w3l + lo(w3h) + // r4 = carry + hi(w3h) + + XORQ AX, AX // clear the flags + ADOXQ R13, BX + ADOXQ CX, DI + ADCXQ SI, DI + ADOXQ R15, R9 + ADCXQ R8, R9 + ADOXQ R14, R11 + ADCXQ R10, R11 + ADOXQ AX, R12 + ADCXQ AX, R12 + + // r[0] -> BX + // r[1] -> DI + // r[2] -> R9 + // r[3] -> R11 + // r[4] -> R12 + // reduce using single-word Barrett + // mu=2^288 / q -> SI + MOVQ mu<>(SB), SI + MOVQ R11, AX + SHRQ $32, R12, AX + MULQ SI + MULXQ q<>+0(SB), AX, SI + SUBQ AX, BX + SBBQ SI, DI + MULXQ q<>+16(SB), AX, SI + SBBQ AX, R9 + SBBQ SI, R11 + SBBQ $0, R12 + MULXQ q<>+8(SB), AX, SI + SUBQ AX, DI + SBBQ SI, R9 + MULXQ q<>+24(SB), AX, SI + SBBQ AX, R11 + SBBQ SI, R12 + MOVQ res+0(FP), SI + MOVQ BX, 0(SI) + MOVQ DI, 8(SI) + MOVQ R9, 16(SI) + MOVQ R11, 24(SI) + + // TODO @gbotrel check if 2 conditional substracts is guaranteed to be suffisant for mod reduce + SUBQ q<>+0(SB), BX + SBBQ q<>+8(SB), DI + SBBQ q<>+16(SB), R9 + SBBQ 
q<>+24(SB), R11 + SBBQ $0, R12 + JCS done_9 + MOVQ BX, 0(SI) + MOVQ DI, 8(SI) + MOVQ R9, 16(SI) + MOVQ R11, 24(SI) + SUBQ q<>+0(SB), BX + SBBQ q<>+8(SB), DI + SBBQ q<>+16(SB), R9 + SBBQ q<>+24(SB), R11 + SBBQ $0, R12 + JCS done_9 + MOVQ BX, 0(SI) + MOVQ DI, 8(SI) + MOVQ R9, 16(SI) + MOVQ R11, 24(SI) done_9: - MOVQ res+0(FP), AX RET diff --git a/ecc/stark-curve/fp/element_test.go b/ecc/stark-curve/fp/element_test.go index 8da5b0e60..c51d1708a 100644 --- a/ecc/stark-curve/fp/element_test.go +++ b/ecc/stark-curve/fp/element_test.go @@ -710,7 +710,7 @@ func TestElementLexicographicallyLargest(t *testing.T) { func TestElementVecOps(t *testing.T) { assert := require.New(t) - const N = 7 + const N = 1024*16 + 4 a := make(Vector, N) b := make(Vector, N) c := make(Vector, N) @@ -749,13 +749,14 @@ func TestElementVecOps(t *testing.T) { for i := 0; i < N; i++ { sum.Add(&sum, &c[i]) } + assert.True(sum.Equal(&computed), "Vector sum failed") } func BenchmarkElementVecOps(b *testing.B) { // note; to benchmark against "no asm" version, use the following // build tag: -tags purego - const N = 1024 + const N = 1 << 20 a1 := make(Vector, N) b1 := make(Vector, N) c1 := make(Vector, N) diff --git a/ecc/stark-curve/fr/asm.go b/ecc/stark-curve/fr/asm.go index da061913b..c756b153a 100644 --- a/ecc/stark-curve/fr/asm.go +++ b/ecc/stark-curve/fr/asm.go @@ -22,6 +22,8 @@ package fr import "golang.org/x/sys/cpu" var ( - supportAdx = cpu.X86.HasADX && cpu.X86.HasBMI2 - _ = supportAdx + supportAdx = cpu.X86.HasADX && cpu.X86.HasBMI2 + _ = supportAdx + supportAvx512 = cpu.X86.HasAVX512 && cpu.X86.HasAVX512DQ + _ = supportAvx512 ) diff --git a/ecc/stark-curve/fr/asm_noadx.go b/ecc/stark-curve/fr/asm_noadx.go index 7f52ffa19..a116cd66e 100644 --- a/ecc/stark-curve/fr/asm_noadx.go +++ b/ecc/stark-curve/fr/asm_noadx.go @@ -23,6 +23,8 @@ package fr // certain errors (like fatal error: missing stackmap) // this ensures we test all asm path. var ( - supportAdx = false - _ = supportAdx + supportAdx = false + _ = supportAdx + supportAvx512 = false + _ = supportAvx512 ) diff --git a/ecc/stark-curve/fr/element_mul_amd64.s b/ecc/stark-curve/fr/element_mul_amd64.s index 8eb931e77..f773f8d0d 100644 --- a/ecc/stark-curve/fr/element_mul_amd64.s +++ b/ecc/stark-curve/fr/element_mul_amd64.s @@ -27,6 +27,9 @@ GLOBL q<>(SB), (RODATA+NOPTR), $32 // qInv0 q'[0] DATA qInv0<>(SB)/8, $0xbb6b3c4ce8bde631 GLOBL qInv0<>(SB), (RODATA+NOPTR), $8 +// Mu +DATA mu<>(SB)/8, $0x0000001fffffffff +GLOBL mu<>(SB), (RODATA+NOPTR), $8 #define REDUCE(ra0, ra1, ra2, ra3, rb0, rb1, rb2, rb3) \ MOVQ ra0, rb0; \ diff --git a/ecc/stark-curve/fr/element_ops_amd64.go b/ecc/stark-curve/fr/element_ops_amd64.go index 88c2d0c58..fe4606378 100644 --- a/ecc/stark-curve/fr/element_ops_amd64.go +++ b/ecc/stark-curve/fr/element_ops_amd64.go @@ -83,7 +83,12 @@ func scalarMulVec(res, a, b *Element, n uint64) // Sum computes the sum of all elements in the vector. 
func (vector *Vector) Sum() (res Element) { - if len(*vector) == 0 { + n := uint64(len(*vector)) + const minN = 16 * 7 // AVX512 slower than generic for small n + const maxN = (1 << 32) + 1 + if !supportAvx512 || n <= minN || n >= maxN { + // call sumVecGeneric + sumVecGeneric(&res, *vector) return } sumVec(&res, &(*vector)[0], uint64(len(*vector))) diff --git a/ecc/stark-curve/fr/element_ops_amd64.s b/ecc/stark-curve/fr/element_ops_amd64.s index ee525581c..014740d05 100644 --- a/ecc/stark-curve/fr/element_ops_amd64.s +++ b/ecc/stark-curve/fr/element_ops_amd64.s @@ -27,6 +27,9 @@ GLOBL q<>(SB), (RODATA+NOPTR), $32 // qInv0 q'[0] DATA qInv0<>(SB)/8, $0xbb6b3c4ce8bde631 GLOBL qInv0<>(SB), (RODATA+NOPTR), $8 +// Mu +DATA mu<>(SB)/8, $0x0000001fffffffff +GLOBL mu<>(SB), (RODATA+NOPTR), $8 #define REDUCE(ra0, ra1, ra2, ra3, rb0, rb1, rb2, rb3) \ MOVQ ra0, rb0; \ @@ -628,46 +631,206 @@ noAdx_5: // sumVec(res, a *Element, n uint64) res = sum(a[0...n]) TEXT ·sumVec(SB), NOSPLIT, $0-24 - MOVQ a+8(FP), AX - MOVQ n+16(FP), DX - VXORQ ZMM0, ZMM0, ZMM0 - VMOVDQA64 ZMM0, ZMM1 - VMOVDQA64 ZMM0, ZMM2 - VMOVDQA64 ZMM0, ZMM3 - TESTQ DX, DX - JEQ done_9 // n == 0, we are done - MOVQ DX, CX - ANDQ $3, CX - SHRQ $2, DX - CMPQ $1, CX - JEQ r1_10 // we have 1 remaining element - CMPQ $2, CX - JEQ r2_11 // we have 2 remaining elements - CMPQ $3, CX - JNE loop_8 // == 0; we have 0 remaining elements + + // Derived from https://github.com/a16z/vectorized-fields + // The idea is to use Z registers to accumulate the sum of elements, 4 by 4 + // first, we handle the case where n % 4 != 0 and add to the accumulators the 1, 2 or 3 remaining elements + // then, we loop over the elements 4 by 4 and accumulate the sum in the Z registers + // finally, we reduce the sum and store it in res + // + // when we move an element of a into a Z register, we use VPMOVZXDQ + // let's note w0...w3 the 4 64bits words of ai: w0 = ai[0], w1 = ai[1], w2 = ai[2], w3 = ai[3] + // VPMOVZXDQ(ai, Z0) will result in + // Z0= [hi(w3), lo(w3), hi(w2), lo(w2), hi(w1), lo(w1), hi(w0), lo(w0)] + // with hi(wi) the high 32 bits of wi and lo(wi) the low 32 bits of wi + // we can safely add 2^32+1 times Z registers constructed this way without overflow + // since each of this lo/hi bits are moved into a "64bits" slot + // N = 2^64-1 / 2^32-1 = 2^32+1 + // + // we then propagate the carry using ADOXQ and ADCXQ + // r0 = w0l + lo(woh) + // r1 = carry + hi(woh) + w1l + lo(w1h) + // r2 = carry + hi(w1h) + w2l + lo(w2h) + // r3 = carry + hi(w2h) + w3l + lo(w3h) + // r4 = carry + hi(w3h) + // we then reduce the sum using a single-word Barrett reduction + + MOVQ a+8(FP), R14 + MOVQ n+16(FP), R15 + + // initialize accumulators Z0, Z1, Z2, Z3 + VXORPS Z0, Z0, Z0 + VMOVDQA64 Z0, Z1 + VMOVDQA64 Z0, Z2 + VMOVDQA64 Z0, Z3 + + // n % 4 -> CX + // n / 4 -> R15 + MOVQ R15, CX + ANDQ $3, CX + SHRQ $2, R15 + CMPQ CX, $1 + JEQ r1_10 // we have 1 remaining element + CMPQ CX, $2 + JEQ r2_11 // we have 2 remaining elements + CMPQ CX, $3 + JNE loop4by4_8 // we have 0 remaining elements // we have 3 remaining elements - VPMOVZXDQ 2*32(AX), ZMM4 - VPADDQ ZMM4, ZMM0, ZMM0 + VPMOVZXDQ 2*32(R14), Z4 + VPADDQ Z4, Z0, Z0 r2_11: // we have 2 remaining elements - VPMOVZXDQ 1*32(AX), ZMM4 - VPADDQ ZMM4, ZMM1, ZMM1 + VPMOVZXDQ 1*32(R14), Z4 + VPADDQ Z4, Z1, Z1 r1_10: // we have 1 remaining element - VPMOVZXDQ 0*32(AX), ZMM4 - VPADDQ ZMM4, ZMM2, ZMM2 - TESTQ DX, DX - JEQ done_9 // n == 0, we are done - -loop_8: - // increment pointers to visit next element - ADDQ $128, AX - DECQ DX // 
decrement n - JMP loop_8 + VPMOVZXDQ 0*32(R14), Z4 + VPADDQ Z4, Z2, Z2 + MOVQ $32, DX + IMULQ CX, DX + ADDQ DX, R14 + +loop4by4_8: + TESTQ R15, R15 + JEQ accumulate_12 // n == 0, we are going to accumulate + VPMOVZXDQ 0*32(R14), Z4 + VPADDQ Z4, Z0, Z0 + VPMOVZXDQ 1*32(R14), Z4 + VPADDQ Z4, Z1, Z1 + VPMOVZXDQ 2*32(R14), Z4 + VPADDQ Z4, Z2, Z2 + VPMOVZXDQ 3*32(R14), Z4 + VPADDQ Z4, Z3, Z3 + + // increment pointers to visit next 4 elements + ADDQ $128, R14 + DECQ R15 // decrement n + JMP loop4by4_8 + +accumulate_12: + // accumulate the 4 Z registers into Z0 + VPADDQ Z1, Z0, Z0 + VPADDQ Z3, Z2, Z2 + VPADDQ Z2, Z0, Z0 + + // carry propagation + // lo(w0) -> BX + // hi(w0) -> SI + // lo(w1) -> DI + // hi(w1) -> R8 + // lo(w2) -> R9 + // hi(w2) -> R10 + // lo(w3) -> R11 + // hi(w3) -> R12 + VMOVQ X0, BX + VALIGNQ $1, Z0, Z0, Z0 + VMOVQ X0, SI + VALIGNQ $1, Z0, Z0, Z0 + VMOVQ X0, DI + VALIGNQ $1, Z0, Z0, Z0 + VMOVQ X0, R8 + VALIGNQ $1, Z0, Z0, Z0 + VMOVQ X0, R9 + VALIGNQ $1, Z0, Z0, Z0 + VMOVQ X0, R10 + VALIGNQ $1, Z0, Z0, Z0 + VMOVQ X0, R11 + VALIGNQ $1, Z0, Z0, Z0 + VMOVQ X0, R12 + + // lo(hi(wo)) -> R13 + // lo(hi(w1)) -> CX + // lo(hi(w2)) -> R15 + // lo(hi(w3)) -> R14 + XORQ AX, AX // clear the flags + MOVQ SI, R13 + ANDQ $0xffffffff, R13 + SHLQ $32, R13 + SHRQ $32, SI + MOVQ R8, CX + ANDQ $0xffffffff, CX + SHLQ $32, CX + SHRQ $32, R8 + MOVQ R10, R15 + ANDQ $0xffffffff, R15 + SHLQ $32, R15 + SHRQ $32, R10 + MOVQ R12, R14 + ANDQ $0xffffffff, R14 + SHLQ $32, R14 + SHRQ $32, R12 + + // r0 = w0l + lo(woh) + // r1 = carry + hi(woh) + w1l + lo(w1h) + // r2 = carry + hi(w1h) + w2l + lo(w2h) + // r3 = carry + hi(w2h) + w3l + lo(w3h) + // r4 = carry + hi(w3h) + + XORQ AX, AX // clear the flags + ADOXQ R13, BX + ADOXQ CX, DI + ADCXQ SI, DI + ADOXQ R15, R9 + ADCXQ R8, R9 + ADOXQ R14, R11 + ADCXQ R10, R11 + ADOXQ AX, R12 + ADCXQ AX, R12 + + // r[0] -> BX + // r[1] -> DI + // r[2] -> R9 + // r[3] -> R11 + // r[4] -> R12 + // reduce using single-word Barrett + // mu=2^288 / q -> SI + MOVQ mu<>(SB), SI + MOVQ R11, AX + SHRQ $32, R12, AX + MULQ SI + MULXQ q<>+0(SB), AX, SI + SUBQ AX, BX + SBBQ SI, DI + MULXQ q<>+16(SB), AX, SI + SBBQ AX, R9 + SBBQ SI, R11 + SBBQ $0, R12 + MULXQ q<>+8(SB), AX, SI + SUBQ AX, DI + SBBQ SI, R9 + MULXQ q<>+24(SB), AX, SI + SBBQ AX, R11 + SBBQ SI, R12 + MOVQ res+0(FP), SI + MOVQ BX, 0(SI) + MOVQ DI, 8(SI) + MOVQ R9, 16(SI) + MOVQ R11, 24(SI) + + // TODO @gbotrel check if 2 conditional substracts is guaranteed to be suffisant for mod reduce + SUBQ q<>+0(SB), BX + SBBQ q<>+8(SB), DI + SBBQ q<>+16(SB), R9 + SBBQ q<>+24(SB), R11 + SBBQ $0, R12 + JCS done_9 + MOVQ BX, 0(SI) + MOVQ DI, 8(SI) + MOVQ R9, 16(SI) + MOVQ R11, 24(SI) + SUBQ q<>+0(SB), BX + SBBQ q<>+8(SB), DI + SBBQ q<>+16(SB), R9 + SBBQ q<>+24(SB), R11 + SBBQ $0, R12 + JCS done_9 + MOVQ BX, 0(SI) + MOVQ DI, 8(SI) + MOVQ R9, 16(SI) + MOVQ R11, 24(SI) done_9: - MOVQ res+0(FP), AX RET diff --git a/ecc/stark-curve/fr/element_test.go b/ecc/stark-curve/fr/element_test.go index 6d0d2ec9e..a3478b0db 100644 --- a/ecc/stark-curve/fr/element_test.go +++ b/ecc/stark-curve/fr/element_test.go @@ -710,7 +710,7 @@ func TestElementLexicographicallyLargest(t *testing.T) { func TestElementVecOps(t *testing.T) { assert := require.New(t) - const N = 7 + const N = 1024*16 + 4 a := make(Vector, N) b := make(Vector, N) c := make(Vector, N) @@ -749,13 +749,14 @@ func TestElementVecOps(t *testing.T) { for i := 0; i < N; i++ { sum.Add(&sum, &c[i]) } + assert.True(sum.Equal(&computed), "Vector sum failed") } func BenchmarkElementVecOps(b *testing.B) { 
// note; to benchmark against "no asm" version, use the following // build tag: -tags purego - const N = 1024 + const N = 1 << 20 a1 := make(Vector, N) b1 := make(Vector, N) c1 := make(Vector, N) diff --git a/field/generator/asm/amd64/element_vec.go b/field/generator/asm/amd64/element_vec.go index c51c1d789..ce3dfe8bc 100644 --- a/field/generator/asm/amd64/element_vec.go +++ b/field/generator/asm/amd64/element_vec.go @@ -248,21 +248,45 @@ func (f *FFAmd64) generateSumVec() { f.Comment("sumVec(res, a *Element, n uint64) res = sum(a[0...n])") const argSize = 3 * 8 - stackSize := f.StackSize(f.NbWords*2+4, 0, 0) + stackSize := f.StackSize(f.NbWords*3+2, 0, 0) registers := f.FnHeader("sumVec", stackSize, argSize, amd64.DX, amd64.AX) defer f.AssertCleanStack(stackSize, 0) + f.WriteLn(` + // Derived from https://github.com/a16z/vectorized-fields + // The idea is to use Z registers to accumulate the sum of elements, 4 by 4 + // first, we handle the case where n % 4 != 0 and add to the accumulators the 1, 2 or 3 remaining elements + // then, we loop over the elements 4 by 4 and accumulate the sum in the Z registers + // finally, we reduce the sum and store it in res + // + // when we move an element of a into a Z register, we use VPMOVZXDQ + // let's note w0...w3 the 4 64bits words of ai: w0 = ai[0], w1 = ai[1], w2 = ai[2], w3 = ai[3] + // VPMOVZXDQ(ai, Z0) will result in + // Z0= [hi(w3), lo(w3), hi(w2), lo(w2), hi(w1), lo(w1), hi(w0), lo(w0)] + // with hi(wi) the high 32 bits of wi and lo(wi) the low 32 bits of wi + // we can safely add 2^32+1 times Z registers constructed this way without overflow + // since each of this lo/hi bits are moved into a "64bits" slot + // N = 2^64-1 / 2^32-1 = 2^32+1 + // + // we then propagate the carry using ADOXQ and ADCXQ + // r0 = w0l + lo(woh) + // r1 = carry + hi(woh) + w1l + lo(w1h) + // r2 = carry + hi(w1h) + w2l + lo(w2h) + // r3 = carry + hi(w2h) + w3l + lo(w3h) + // r4 = carry + hi(w3h) + // we then reduce the sum using a single-word Barrett reduction + `) + // registers & labels we need addrA := f.Pop(®isters) - len := f.Pop(®isters) - tmp0 := f.Pop(®isters) + n := f.Pop(®isters) + nMod4 := f.Pop(®isters) - loop := f.NewLabel("loop") + loop := f.NewLabel("loop4by4") done := f.NewLabel("done") - rr1 := f.NewLabel("rr1") - rr2 := f.NewLabel("rr2") + rr1 := f.NewLabel("r1") + rr2 := f.NewLabel("r2") accumulate := f.NewLabel("accumulate") - // propagate := f.NewLabel("propagate") // AVX512 registers Z0 := amd64.Register("Z0") @@ -272,37 +296,37 @@ func (f *FFAmd64) generateSumVec() { Z4 := amd64.Register("Z4") X0 := amd64.Register("X0") - f.XORQ(amd64.AX, amd64.AX) - // load arguments f.MOVQ("a+8(FP)", addrA) - f.MOVQ("n+16(FP)", len) + f.MOVQ("n+16(FP)", n) - // initialize accumulators to zero (zmm0, zmm1, zmm2, zmm3) + f.Comment("initialize accumulators Z0, Z1, Z2, Z3") f.VXORPS(Z0, Z0, Z0) f.VMOVDQA64(Z0, Z1) f.VMOVDQA64(Z0, Z2) f.VMOVDQA64(Z0, Z3) - f.TESTQ(len, len) - f.JEQ(done, "n == 0, we are done") + // note: we don't need to handle the case n==0; handled by caller already. 
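To make the zero-extension trick described in the generated comment concrete, here is a pure-Go model of what one VPMOVZXDQ + VPADDQ step accumulates into a Z register. This is an illustrative sketch only, not code the generator emits:

// Each 64-bit word of an element is split into its 32-bit halves and each
// half is accumulated in its own 64-bit lane. A lane therefore holds a sum
// of values < 2^32, so (2^64-1)/(2^32-1) = 2^32+1 additions fit without
// overflowing a lane.
func accumulateLanes(acc *[8]uint64, el [4]uint64) {
	for i, w := range el {
		acc[2*i] += w & 0xffffffff // lo(w_i)
		acc[2*i+1] += w >> 32      // hi(w_i)
	}
}

In the assembly, four such accumulators (Z0..Z3) each take every fourth element, which is why the main loop advances the input pointer by 128 bytes (4 elements of 32 bytes) per iteration.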
+ // f.TESTQ(n, n) + // f.JEQ(done, "n == 0, we are done") - f.MOVQ(len, tmp0) - f.ANDQ("$3", tmp0) // t0 = n % 4 - f.SHRQ("$2", len) // len = n / 4 + f.LabelRegisters("n % 4", nMod4) + f.LabelRegisters("n / 4", n) + f.MOVQ(n, nMod4) + f.ANDQ("$3", nMod4) // t0 = n % 4 + f.SHRQ("$2", n) // len = n / 4 // if len % 4 != 0, we need to handle the remaining elements - f.CMPQ(tmp0, "$1") + f.CMPQ(nMod4, "$1") f.JEQ(rr1, "we have 1 remaining element") - f.CMPQ(tmp0, "$2") + f.CMPQ(nMod4, "$2") f.JEQ(rr2, "we have 2 remaining elements") - f.CMPQ(tmp0, "$3") - f.JNE(loop, "== 0; we have 0 remaining elements") + f.CMPQ(nMod4, "$3") + f.JNE(loop, "we have 0 remaining elements") f.Comment("we have 3 remaining elements") - // vpmovzxdq 2*32(PX), %zmm4; vpaddq %zmm4, %zmm0, %zmm0 f.VPMOVZXDQ("2*32("+addrA+")", Z4) f.VPADDQ(Z4, Z0, Z0) @@ -321,14 +345,14 @@ func (f *FFAmd64) generateSumVec() { // mul $32 by tmp0 // TODO use better instructions f.MOVQ("$32", amd64.DX) - f.IMULQ(tmp0, amd64.DX) + f.IMULQ(nMod4, amd64.DX) f.ADDQ(amd64.DX, addrA) - f.Push(®isters, tmp0) // we don't need tmp0 - tmp0 = "" + f.Push(®isters, nMod4) // we don't need tmp0 + nMod4 = "" f.LABEL(loop) - f.TESTQ(len, len) + f.TESTQ(n, n) f.JEQ(accumulate, "n == 0, we are going to accumulate") f.VPMOVZXDQ("0*32("+addrA+")", Z4) @@ -345,15 +369,16 @@ func (f *FFAmd64) generateSumVec() { f.Comment("increment pointers to visit next 4 elements") f.ADDQ("$128", addrA) - f.DECQ(len, "decrement n") + f.DECQ(n, "decrement n") f.JMP(loop) - f.Push(®isters, len, addrA) // we don't need len - len = "" + f.Push(®isters, n, addrA) // we don't need len + n = "" addrA = "" f.LABEL(accumulate) + f.Comment("accumulate the 4 Z registers into Z0") f.VPADDQ(Z1, Z0, Z0) f.VPADDQ(Z3, Z2, Z2) f.VPADDQ(Z2, Z0, Z0) @@ -372,6 +397,17 @@ func (f *FFAmd64) generateSumVec() { low3h := f.Pop(®isters) // Propagate carries + f.Comment("carry propagation") + + f.LabelRegisters("lo(w0)", w0l) + f.LabelRegisters("hi(w0)", w0h) + f.LabelRegisters("lo(w1)", w1l) + f.LabelRegisters("hi(w1)", w1h) + f.LabelRegisters("lo(w2)", w2l) + f.LabelRegisters("hi(w2)", w2h) + f.LabelRegisters("lo(w3)", w3l) + f.LabelRegisters("hi(w3)", w3h) + f.VMOVQ(X0, w0l) f.VALIGNQ("$1", Z0, Z0, Z0) f.VMOVQ(X0, w0h) @@ -388,13 +424,12 @@ func (f *FFAmd64) generateSumVec() { f.VALIGNQ("$1", Z0, Z0, Z0) f.VMOVQ(X0, w3h) - // r0 = w0l + lo(woh) - // r1 = carry + hi(woh) + w1l + lo(w1h) - // r2 = carry + hi(w1h) + w2l + lo(w2h) - // r3 = carry + hi(w2h) + w3l + lo(w3h) + f.LabelRegisters("lo(hi(wo))", low0h) + f.LabelRegisters("lo(hi(w1))", low1h) + f.LabelRegisters("lo(hi(w2))", low2h) + f.LabelRegisters("lo(hi(w3))", low3h) - // we need 2 carry so we use ADOXQ and ADCXQ - f.XORQ(amd64.AX, amd64.AX) + f.XORQ(amd64.AX, amd64.AX, "clear the flags") type hilo struct { hi, lo amd64.Register } @@ -405,8 +440,14 @@ func (f *FFAmd64) generateSumVec() { f.SHRQ("$32", v.hi) } - f.XORQ(amd64.AX, amd64.AX) - // start the carry chain + f.WriteLn(` + // r0 = w0l + lo(woh) + // r1 = carry + hi(woh) + w1l + lo(w1h) + // r2 = carry + hi(w1h) + w2l + lo(w2h) + // r3 = carry + hi(w2h) + w3l + lo(w3h) + // r4 = carry + hi(w3h) + `) + f.XORQ(amd64.AX, amd64.AX, "clear the flags") f.ADOXQ(low0h, w0l) f.ADOXQ(low1h, w1l) @@ -426,7 +467,8 @@ func (f *FFAmd64) generateSumVec() { r2 := w2l r3 := w3l r4 := w3h - r := []amd64.Register{r0, r1, r2, r3} + r := []amd64.Register{r0, r1, r2, r3, r4} + f.LabelRegisters("r", r...) 
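The ADOXQ/ADCXQ chain generated above folds the eight accumulator lanes back into the five 64-bit words r0..r4. An equivalent pure-Go model of that fold, following the r0..r4 recurrence from the generated comments (a sketch for reference, not emitted code):

import "math/bits"

// foldLanes turns the 8 lanes (lo/hi halves of words w0..w3) back into a
// plain 5-word little-endian integer r0..r4.
func foldLanes(acc [8]uint64) (r [5]uint64) {
	var carry uint64
	for i := 0; i < 4; i++ {
		lo, hi := acc[2*i], acc[2*i+1]
		var c1, c2 uint64
		r[i], c1 = bits.Add64(lo, hi<<32, 0)  // lo-lane + (low 32 bits of hi-lane) << 32
		r[i], c2 = bits.Add64(r[i], carry, 0) // plus the carry from the previous word
		carry = (hi >> 32) + c1 + c2          // high 32 bits of hi-lane flow into the next word
	}
	r[4] = carry
	return
}

The Barrett step that follows reduces this 5-word value modulo q.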
// we don't need w0h, w1h, w2h anymore f.Push(®isters, w0h, w1h, w2h) w0h = "" @@ -440,13 +482,10 @@ func (f *FFAmd64) generateSumVec() { low3h = "" // Reduce using single-word Barrett - // q1 is low 32 bits of T4 and high 32 bits of T3 - // movq T3, %rax - // shrd $32, T4, %rax - // mulq MU // Multiply by mu. q2 in rdx:rax, q3 in rdx mu := f.Pop(®isters) - f.XORQ(amd64.AX, amd64.AX) + f.Comment("reduce using single-word Barrett") + f.LabelRegisters("mu=2^288 / q", mu) f.MOVQ(f.mu(), mu) f.MOVQ(r3, amd64.AX) f.SHRQw("$32", r4, amd64.AX) @@ -474,6 +513,7 @@ func (f *FFAmd64) generateSumVec() { f.Mov(r, addrRes) // sub modulus + f.Comment("TODO @gbotrel check if 2 conditional substracts is guaranteed to be suffisant for mod reduce") f.SUBQ(f.qAt(0), r0) f.SBBQ(f.qAt(1), r1) f.SBBQ(f.qAt(2), r2) diff --git a/field/generator/internal/templates/element/asm.go b/field/generator/internal/templates/element/asm.go index c1027f148..5a73aa53b 100644 --- a/field/generator/internal/templates/element/asm.go +++ b/field/generator/internal/templates/element/asm.go @@ -7,6 +7,10 @@ import "golang.org/x/sys/cpu" var ( supportAdx = cpu.X86.HasADX && cpu.X86.HasBMI2 _ = supportAdx + {{- if eq .NbWords 4}} + supportAvx512 = cpu.X86.HasAVX512 && cpu.X86.HasAVX512DQ + _ = supportAvx512 + {{- end}} ) ` @@ -19,5 +23,9 @@ const AsmNoAdx = ` var ( supportAdx = false _ = supportAdx + {{- if eq .NbWords 4}} + supportAvx512 = false + _ = supportAvx512 + {{- end}} ) ` diff --git a/field/generator/internal/templates/element/ops_asm.go b/field/generator/internal/templates/element/ops_asm.go index d2aa9c887..061e95d9a 100644 --- a/field/generator/internal/templates/element/ops_asm.go +++ b/field/generator/internal/templates/element/ops_asm.go @@ -68,7 +68,12 @@ func scalarMulVec(res, a, b *{{.ElementName}}, n uint64) // Sum computes the sum of all elements in the vector. 
func (vector *Vector) Sum() (res {{.ElementName}}) { - if len(*vector) == 0 { + n := uint64(len(*vector)) + const minN = 16*7 // AVX512 slower than generic for small n + const maxN = (1 << 32) + 1 + if !supportAvx512 || n <= minN || n >= maxN { + // call sumVecGeneric + sumVecGeneric(&res, *vector) return } sumVec(&res, &(*vector)[0], uint64(len(*vector))) diff --git a/field/generator/internal/templates/element/tests.go b/field/generator/internal/templates/element/tests.go index 5d9d23a75..b8ba75b58 100644 --- a/field/generator/internal/templates/element/tests.go +++ b/field/generator/internal/templates/element/tests.go @@ -731,7 +731,7 @@ func Test{{toTitle .ElementName}}LexicographicallyLargest(t *testing.T) { func Test{{toTitle .ElementName}}VecOps(t *testing.T) { assert := require.New(t) - const N = 4*16 + 4 + const N = 1024*16 + 4 a := make(Vector, N) b := make(Vector, N) c := make(Vector, N) @@ -777,7 +777,7 @@ func Test{{toTitle .ElementName}}VecOps(t *testing.T) { func Benchmark{{toTitle .ElementName}}VecOps(b *testing.B) { // note; to benchmark against "no asm" version, use the following // build tag: -tags purego - const N = 1024 * 10 + const N = 1<<20 a1 := make(Vector, N) b1 := make(Vector, N) c1 := make(Vector, N) diff --git a/field/goldilocks/element_test.go b/field/goldilocks/element_test.go index 64b1dddd8..f4d6606ff 100644 --- a/field/goldilocks/element_test.go +++ b/field/goldilocks/element_test.go @@ -655,7 +655,7 @@ func TestElementLexicographicallyLargest(t *testing.T) { func TestElementVecOps(t *testing.T) { assert := require.New(t) - const N = 4*16 + 4 + const N = 1024*16 + 4 a := make(Vector, N) b := make(Vector, N) c := make(Vector, N) @@ -701,7 +701,7 @@ func TestElementVecOps(t *testing.T) { func BenchmarkElementVecOps(b *testing.B) { // note; to benchmark against "no asm" version, use the following // build tag: -tags purego - const N = 1024 * 10 + const N = 1 << 20 a1 := make(Vector, N) b1 := make(Vector, N) c1 := make(Vector, N) diff --git a/go.mod b/go.mod index d2ac46944..dd93d660a 100644 --- a/go.mod +++ b/go.mod @@ -4,7 +4,7 @@ go 1.22 require ( github.com/bits-and-blooms/bitset v1.14.2 - github.com/consensys/bavard v0.0.0 + github.com/consensys/bavard v0.1.16 github.com/leanovate/gopter v0.2.11 github.com/mmcloughlin/addchain v0.4.0 github.com/spf13/cobra v1.8.1 @@ -14,8 +14,6 @@ require ( gopkg.in/yaml.v2 v2.4.0 ) -replace github.com/consensys/bavard => ../bavard - require ( github.com/davecgh/go-spew v1.1.1 // indirect github.com/inconshreveable/mousetrap v1.1.0 // indirect diff --git a/go.sum b/go.sum index 2c324c00b..3d91ce1bf 100644 --- a/go.sum +++ b/go.sum @@ -55,8 +55,8 @@ github.com/client9/misspell v0.3.4/go.mod h1:qj6jICC3Q7zFZvVWo7KLAzC3yx5G7kyvSDk github.com/cncf/udpa/go v0.0.0-20191209042840-269d4d468f6f/go.mod h1:M8M6+tZqaGXZJjfX53e64911xZQV5JYwmTeXPW+k8Sc= github.com/cncf/udpa/go v0.0.0-20200629203442-efcf912fb354/go.mod h1:WmhPx2Nbnhtbo57+VJT5O0JRkEi1Wbu0z5j0R8u5Hbk= github.com/cncf/udpa/go v0.0.0-20201120205902-5459f2c99403/go.mod h1:WmhPx2Nbnhtbo57+VJT5O0JRkEi1Wbu0z5j0R8u5Hbk= -github.com/consensys/bavard v0.1.15 h1:fxv2mg1afRMJvZgpwEgLmyr2MsQwaAYcyKf31UBHzw4= -github.com/consensys/bavard v0.1.15/go.mod h1:9ItSMtA/dXMAiL7BG6bqW2m3NdSEObYWoH223nGHukI= +github.com/consensys/bavard v0.1.16 h1:3+nT0BqzKg84VtCY9eNN2Glkf1X7dbS5yhh5849syJ8= +github.com/consensys/bavard v0.1.16/go.mod h1:9ItSMtA/dXMAiL7BG6bqW2m3NdSEObYWoH223nGHukI= github.com/coreos/go-semver v0.3.0/go.mod h1:nnelYz7RCh+5ahJtPPxZlU+153eP4D4r3EedlOD2RNk= 
github.com/coreos/go-systemd/v22 v22.3.2/go.mod h1:Y58oyj3AT4RCenI/lSvhwexgC+NSVTIJ3seZv2GcEnc= github.com/cpuguy83/go-md2man/v2 v2.0.0/go.mod h1:maD7wRr/U5Z6m/iR4s+kqSMx2CaBsrgA7czyZG/E6dU= diff --git a/internal/generator/main.go b/internal/generator/main.go index ad3a46039..389f96c2e 100644 --- a/internal/generator/main.go +++ b/internal/generator/main.go @@ -46,9 +46,6 @@ func main() { var wg sync.WaitGroup for _, conf := range config.Curves { - if !conf.Equal(config.BLS12_377) { - continue - } wg.Add(1) // for each curve, generate the needed files go func(conf config.Curve) { From 4796eb36c9a3f74247b9f9dbe850c5d988edd3e9 Mon Sep 17 00:00:00 2001 From: Gautam Botrel Date: Sun, 22 Sep 2024 02:33:17 +0000 Subject: [PATCH 06/14] test: make odd bound for better test case --- ecc/bls12-377/fp/element_test.go | 2 +- ecc/bls12-377/fr/element_test.go | 2 +- ecc/bls12-381/fp/element_test.go | 2 +- ecc/bls12-381/fr/element_test.go | 2 +- ecc/bls24-315/fp/element_test.go | 2 +- ecc/bls24-315/fr/element_test.go | 2 +- ecc/bls24-317/fp/element_test.go | 2 +- ecc/bls24-317/fr/element_test.go | 2 +- ecc/bn254/fp/element_test.go | 2 +- ecc/bn254/fr/element_test.go | 2 +- ecc/bw6-633/fp/element_test.go | 2 +- ecc/bw6-633/fr/element_test.go | 2 +- ecc/bw6-761/fp/element_test.go | 2 +- ecc/bw6-761/fr/element_test.go | 2 +- ecc/secp256k1/fp/element_test.go | 2 +- ecc/secp256k1/fr/element_test.go | 2 +- ecc/stark-curve/fp/element_test.go | 2 +- ecc/stark-curve/fr/element_test.go | 2 +- field/generator/internal/templates/element/tests.go | 2 +- field/goldilocks/element_test.go | 2 +- 20 files changed, 20 insertions(+), 20 deletions(-) diff --git a/ecc/bls12-377/fp/element_test.go b/ecc/bls12-377/fp/element_test.go index 8376c41e2..a9e630152 100644 --- a/ecc/bls12-377/fp/element_test.go +++ b/ecc/bls12-377/fp/element_test.go @@ -714,7 +714,7 @@ func TestElementLexicographicallyLargest(t *testing.T) { func TestElementVecOps(t *testing.T) { assert := require.New(t) - const N = 1024*16 + 4 + const N = 1024*16 + 3 a := make(Vector, N) b := make(Vector, N) c := make(Vector, N) diff --git a/ecc/bls12-377/fr/element_test.go b/ecc/bls12-377/fr/element_test.go index b576a09f4..265719d23 100644 --- a/ecc/bls12-377/fr/element_test.go +++ b/ecc/bls12-377/fr/element_test.go @@ -710,7 +710,7 @@ func TestElementLexicographicallyLargest(t *testing.T) { func TestElementVecOps(t *testing.T) { assert := require.New(t) - const N = 1024*16 + 4 + const N = 1024*16 + 3 a := make(Vector, N) b := make(Vector, N) c := make(Vector, N) diff --git a/ecc/bls12-381/fp/element_test.go b/ecc/bls12-381/fp/element_test.go index 88974f2cd..6e51a0915 100644 --- a/ecc/bls12-381/fp/element_test.go +++ b/ecc/bls12-381/fp/element_test.go @@ -714,7 +714,7 @@ func TestElementLexicographicallyLargest(t *testing.T) { func TestElementVecOps(t *testing.T) { assert := require.New(t) - const N = 1024*16 + 4 + const N = 1024*16 + 3 a := make(Vector, N) b := make(Vector, N) c := make(Vector, N) diff --git a/ecc/bls12-381/fr/element_test.go b/ecc/bls12-381/fr/element_test.go index ff0b3c188..2005553e1 100644 --- a/ecc/bls12-381/fr/element_test.go +++ b/ecc/bls12-381/fr/element_test.go @@ -710,7 +710,7 @@ func TestElementLexicographicallyLargest(t *testing.T) { func TestElementVecOps(t *testing.T) { assert := require.New(t) - const N = 1024*16 + 4 + const N = 1024*16 + 3 a := make(Vector, N) b := make(Vector, N) c := make(Vector, N) diff --git a/ecc/bls24-315/fp/element_test.go b/ecc/bls24-315/fp/element_test.go index bdf5071ad..0e397c4c7 100644 --- 
a/ecc/bls24-315/fp/element_test.go +++ b/ecc/bls24-315/fp/element_test.go @@ -712,7 +712,7 @@ func TestElementLexicographicallyLargest(t *testing.T) { func TestElementVecOps(t *testing.T) { assert := require.New(t) - const N = 1024*16 + 4 + const N = 1024*16 + 3 a := make(Vector, N) b := make(Vector, N) c := make(Vector, N) diff --git a/ecc/bls24-315/fr/element_test.go b/ecc/bls24-315/fr/element_test.go index 007beea50..ca15b880d 100644 --- a/ecc/bls24-315/fr/element_test.go +++ b/ecc/bls24-315/fr/element_test.go @@ -710,7 +710,7 @@ func TestElementLexicographicallyLargest(t *testing.T) { func TestElementVecOps(t *testing.T) { assert := require.New(t) - const N = 1024*16 + 4 + const N = 1024*16 + 3 a := make(Vector, N) b := make(Vector, N) c := make(Vector, N) diff --git a/ecc/bls24-317/fp/element_test.go b/ecc/bls24-317/fp/element_test.go index 9a20746bc..1ef36f59f 100644 --- a/ecc/bls24-317/fp/element_test.go +++ b/ecc/bls24-317/fp/element_test.go @@ -712,7 +712,7 @@ func TestElementLexicographicallyLargest(t *testing.T) { func TestElementVecOps(t *testing.T) { assert := require.New(t) - const N = 1024*16 + 4 + const N = 1024*16 + 3 a := make(Vector, N) b := make(Vector, N) c := make(Vector, N) diff --git a/ecc/bls24-317/fr/element_test.go b/ecc/bls24-317/fr/element_test.go index 881f55ff3..7e887d602 100644 --- a/ecc/bls24-317/fr/element_test.go +++ b/ecc/bls24-317/fr/element_test.go @@ -710,7 +710,7 @@ func TestElementLexicographicallyLargest(t *testing.T) { func TestElementVecOps(t *testing.T) { assert := require.New(t) - const N = 1024*16 + 4 + const N = 1024*16 + 3 a := make(Vector, N) b := make(Vector, N) c := make(Vector, N) diff --git a/ecc/bn254/fp/element_test.go b/ecc/bn254/fp/element_test.go index 14847a533..e6250c4c8 100644 --- a/ecc/bn254/fp/element_test.go +++ b/ecc/bn254/fp/element_test.go @@ -710,7 +710,7 @@ func TestElementLexicographicallyLargest(t *testing.T) { func TestElementVecOps(t *testing.T) { assert := require.New(t) - const N = 1024*16 + 4 + const N = 1024*16 + 3 a := make(Vector, N) b := make(Vector, N) c := make(Vector, N) diff --git a/ecc/bn254/fr/element_test.go b/ecc/bn254/fr/element_test.go index 8c8e431dc..2ace9113a 100644 --- a/ecc/bn254/fr/element_test.go +++ b/ecc/bn254/fr/element_test.go @@ -710,7 +710,7 @@ func TestElementLexicographicallyLargest(t *testing.T) { func TestElementVecOps(t *testing.T) { assert := require.New(t) - const N = 1024*16 + 4 + const N = 1024*16 + 3 a := make(Vector, N) b := make(Vector, N) c := make(Vector, N) diff --git a/ecc/bw6-633/fp/element_test.go b/ecc/bw6-633/fp/element_test.go index 73b3146b6..c332c558f 100644 --- a/ecc/bw6-633/fp/element_test.go +++ b/ecc/bw6-633/fp/element_test.go @@ -722,7 +722,7 @@ func TestElementLexicographicallyLargest(t *testing.T) { func TestElementVecOps(t *testing.T) { assert := require.New(t) - const N = 1024*16 + 4 + const N = 1024*16 + 3 a := make(Vector, N) b := make(Vector, N) c := make(Vector, N) diff --git a/ecc/bw6-633/fr/element_test.go b/ecc/bw6-633/fr/element_test.go index a4b82539b..d805a78a6 100644 --- a/ecc/bw6-633/fr/element_test.go +++ b/ecc/bw6-633/fr/element_test.go @@ -712,7 +712,7 @@ func TestElementLexicographicallyLargest(t *testing.T) { func TestElementVecOps(t *testing.T) { assert := require.New(t) - const N = 1024*16 + 4 + const N = 1024*16 + 3 a := make(Vector, N) b := make(Vector, N) c := make(Vector, N) diff --git a/ecc/bw6-761/fp/element_test.go b/ecc/bw6-761/fp/element_test.go index d1f404566..e5cddf1ab 100644 --- a/ecc/bw6-761/fp/element_test.go +++ 
b/ecc/bw6-761/fp/element_test.go @@ -726,7 +726,7 @@ func TestElementLexicographicallyLargest(t *testing.T) { func TestElementVecOps(t *testing.T) { assert := require.New(t) - const N = 1024*16 + 4 + const N = 1024*16 + 3 a := make(Vector, N) b := make(Vector, N) c := make(Vector, N) diff --git a/ecc/bw6-761/fr/element_test.go b/ecc/bw6-761/fr/element_test.go index 90d162aee..c483e9d15 100644 --- a/ecc/bw6-761/fr/element_test.go +++ b/ecc/bw6-761/fr/element_test.go @@ -714,7 +714,7 @@ func TestElementLexicographicallyLargest(t *testing.T) { func TestElementVecOps(t *testing.T) { assert := require.New(t) - const N = 1024*16 + 4 + const N = 1024*16 + 3 a := make(Vector, N) b := make(Vector, N) c := make(Vector, N) diff --git a/ecc/secp256k1/fp/element_test.go b/ecc/secp256k1/fp/element_test.go index 4cedc9a12..4a03b4c75 100644 --- a/ecc/secp256k1/fp/element_test.go +++ b/ecc/secp256k1/fp/element_test.go @@ -708,7 +708,7 @@ func TestElementLexicographicallyLargest(t *testing.T) { func TestElementVecOps(t *testing.T) { assert := require.New(t) - const N = 1024*16 + 4 + const N = 1024*16 + 3 a := make(Vector, N) b := make(Vector, N) c := make(Vector, N) diff --git a/ecc/secp256k1/fr/element_test.go b/ecc/secp256k1/fr/element_test.go index f97f76998..3a92c4c76 100644 --- a/ecc/secp256k1/fr/element_test.go +++ b/ecc/secp256k1/fr/element_test.go @@ -708,7 +708,7 @@ func TestElementLexicographicallyLargest(t *testing.T) { func TestElementVecOps(t *testing.T) { assert := require.New(t) - const N = 1024*16 + 4 + const N = 1024*16 + 3 a := make(Vector, N) b := make(Vector, N) c := make(Vector, N) diff --git a/ecc/stark-curve/fp/element_test.go b/ecc/stark-curve/fp/element_test.go index c51d1708a..3f82bcd7c 100644 --- a/ecc/stark-curve/fp/element_test.go +++ b/ecc/stark-curve/fp/element_test.go @@ -710,7 +710,7 @@ func TestElementLexicographicallyLargest(t *testing.T) { func TestElementVecOps(t *testing.T) { assert := require.New(t) - const N = 1024*16 + 4 + const N = 1024*16 + 3 a := make(Vector, N) b := make(Vector, N) c := make(Vector, N) diff --git a/ecc/stark-curve/fr/element_test.go b/ecc/stark-curve/fr/element_test.go index a3478b0db..94be0c490 100644 --- a/ecc/stark-curve/fr/element_test.go +++ b/ecc/stark-curve/fr/element_test.go @@ -710,7 +710,7 @@ func TestElementLexicographicallyLargest(t *testing.T) { func TestElementVecOps(t *testing.T) { assert := require.New(t) - const N = 1024*16 + 4 + const N = 1024*16 + 3 a := make(Vector, N) b := make(Vector, N) c := make(Vector, N) diff --git a/field/generator/internal/templates/element/tests.go b/field/generator/internal/templates/element/tests.go index b8ba75b58..85eafacaf 100644 --- a/field/generator/internal/templates/element/tests.go +++ b/field/generator/internal/templates/element/tests.go @@ -731,7 +731,7 @@ func Test{{toTitle .ElementName}}LexicographicallyLargest(t *testing.T) { func Test{{toTitle .ElementName}}VecOps(t *testing.T) { assert := require.New(t) - const N = 1024*16 + 4 + const N = 1024*16 + 3 a := make(Vector, N) b := make(Vector, N) c := make(Vector, N) diff --git a/field/goldilocks/element_test.go b/field/goldilocks/element_test.go index f4d6606ff..c8cf99312 100644 --- a/field/goldilocks/element_test.go +++ b/field/goldilocks/element_test.go @@ -655,7 +655,7 @@ func TestElementLexicographicallyLargest(t *testing.T) { func TestElementVecOps(t *testing.T) { assert := require.New(t) - const N = 1024*16 + 4 + const N = 1024*16 + 3 a := make(Vector, N) b := make(Vector, N) c := make(Vector, N) From 
dfcb1101070c2308d676b6e025a5e4c9e350a203 Mon Sep 17 00:00:00 2001 From: Gautam Botrel Date: Sun, 22 Sep 2024 02:37:04 +0000 Subject: [PATCH 07/14] build: make linter happy --- ecc/bls12-377/fr/element_ops_amd64.s | 2 +- ecc/bls12-381/fr/element_ops_amd64.s | 2 +- ecc/bls24-315/fr/element_ops_amd64.s | 2 +- ecc/bls24-317/fr/element_ops_amd64.s | 2 +- ecc/bn254/fp/element_ops_amd64.s | 2 +- ecc/bn254/fr/element_ops_amd64.s | 2 +- ecc/stark-curve/fp/element_ops_amd64.s | 2 +- ecc/stark-curve/fr/element_ops_amd64.s | 2 +- field/generator/asm/amd64/element_vec.go | 12 +----------- 9 files changed, 9 insertions(+), 19 deletions(-) diff --git a/ecc/bls12-377/fr/element_ops_amd64.s b/ecc/bls12-377/fr/element_ops_amd64.s index 0fdfd2104..29acbbfc6 100644 --- a/ecc/bls12-377/fr/element_ops_amd64.s +++ b/ecc/bls12-377/fr/element_ops_amd64.s @@ -810,7 +810,7 @@ accumulate_12: MOVQ R9, 16(SI) MOVQ R11, 24(SI) - // TODO @gbotrel check if 2 conditional substracts is guaranteed to be suffisant for mod reduce + // TODO @gbotrel check if 2 conditional subtracts is guaranteed to be suffisant for mod reduce SUBQ q<>+0(SB), BX SBBQ q<>+8(SB), DI SBBQ q<>+16(SB), R9 diff --git a/ecc/bls12-381/fr/element_ops_amd64.s b/ecc/bls12-381/fr/element_ops_amd64.s index 034d36141..fc3206512 100644 --- a/ecc/bls12-381/fr/element_ops_amd64.s +++ b/ecc/bls12-381/fr/element_ops_amd64.s @@ -810,7 +810,7 @@ accumulate_12: MOVQ R9, 16(SI) MOVQ R11, 24(SI) - // TODO @gbotrel check if 2 conditional substracts is guaranteed to be suffisant for mod reduce + // TODO @gbotrel check if 2 conditional subtracts is guaranteed to be suffisant for mod reduce SUBQ q<>+0(SB), BX SBBQ q<>+8(SB), DI SBBQ q<>+16(SB), R9 diff --git a/ecc/bls24-315/fr/element_ops_amd64.s b/ecc/bls24-315/fr/element_ops_amd64.s index 9a05cb98f..74ecdfcfd 100644 --- a/ecc/bls24-315/fr/element_ops_amd64.s +++ b/ecc/bls24-315/fr/element_ops_amd64.s @@ -810,7 +810,7 @@ accumulate_12: MOVQ R9, 16(SI) MOVQ R11, 24(SI) - // TODO @gbotrel check if 2 conditional substracts is guaranteed to be suffisant for mod reduce + // TODO @gbotrel check if 2 conditional subtracts is guaranteed to be suffisant for mod reduce SUBQ q<>+0(SB), BX SBBQ q<>+8(SB), DI SBBQ q<>+16(SB), R9 diff --git a/ecc/bls24-317/fr/element_ops_amd64.s b/ecc/bls24-317/fr/element_ops_amd64.s index ead13527b..e9c2b90c3 100644 --- a/ecc/bls24-317/fr/element_ops_amd64.s +++ b/ecc/bls24-317/fr/element_ops_amd64.s @@ -810,7 +810,7 @@ accumulate_12: MOVQ R9, 16(SI) MOVQ R11, 24(SI) - // TODO @gbotrel check if 2 conditional substracts is guaranteed to be suffisant for mod reduce + // TODO @gbotrel check if 2 conditional subtracts is guaranteed to be suffisant for mod reduce SUBQ q<>+0(SB), BX SBBQ q<>+8(SB), DI SBBQ q<>+16(SB), R9 diff --git a/ecc/bn254/fp/element_ops_amd64.s b/ecc/bn254/fp/element_ops_amd64.s index 87880fe50..838847743 100644 --- a/ecc/bn254/fp/element_ops_amd64.s +++ b/ecc/bn254/fp/element_ops_amd64.s @@ -810,7 +810,7 @@ accumulate_12: MOVQ R9, 16(SI) MOVQ R11, 24(SI) - // TODO @gbotrel check if 2 conditional substracts is guaranteed to be suffisant for mod reduce + // TODO @gbotrel check if 2 conditional subtracts is guaranteed to be suffisant for mod reduce SUBQ q<>+0(SB), BX SBBQ q<>+8(SB), DI SBBQ q<>+16(SB), R9 diff --git a/ecc/bn254/fr/element_ops_amd64.s b/ecc/bn254/fr/element_ops_amd64.s index 49581b69e..1ae6c2c48 100644 --- a/ecc/bn254/fr/element_ops_amd64.s +++ b/ecc/bn254/fr/element_ops_amd64.s @@ -810,7 +810,7 @@ accumulate_12: MOVQ R9, 16(SI) MOVQ R11, 24(SI) - // TODO @gbotrel check 
if 2 conditional substracts is guaranteed to be suffisant for mod reduce + // TODO @gbotrel check if 2 conditional subtracts is guaranteed to be suffisant for mod reduce SUBQ q<>+0(SB), BX SBBQ q<>+8(SB), DI SBBQ q<>+16(SB), R9 diff --git a/ecc/stark-curve/fp/element_ops_amd64.s b/ecc/stark-curve/fp/element_ops_amd64.s index d65b7f196..f9d6112d9 100644 --- a/ecc/stark-curve/fp/element_ops_amd64.s +++ b/ecc/stark-curve/fp/element_ops_amd64.s @@ -810,7 +810,7 @@ accumulate_12: MOVQ R9, 16(SI) MOVQ R11, 24(SI) - // TODO @gbotrel check if 2 conditional substracts is guaranteed to be suffisant for mod reduce + // TODO @gbotrel check if 2 conditional subtracts is guaranteed to be suffisant for mod reduce SUBQ q<>+0(SB), BX SBBQ q<>+8(SB), DI SBBQ q<>+16(SB), R9 diff --git a/ecc/stark-curve/fr/element_ops_amd64.s b/ecc/stark-curve/fr/element_ops_amd64.s index 014740d05..02feed270 100644 --- a/ecc/stark-curve/fr/element_ops_amd64.s +++ b/ecc/stark-curve/fr/element_ops_amd64.s @@ -810,7 +810,7 @@ accumulate_12: MOVQ R9, 16(SI) MOVQ R11, 24(SI) - // TODO @gbotrel check if 2 conditional substracts is guaranteed to be suffisant for mod reduce + // TODO @gbotrel check if 2 conditional subtracts is guaranteed to be suffisant for mod reduce SUBQ q<>+0(SB), BX SBBQ q<>+8(SB), DI SBBQ q<>+16(SB), R9 diff --git a/field/generator/asm/amd64/element_vec.go b/field/generator/asm/amd64/element_vec.go index ce3dfe8bc..c6eee5e14 100644 --- a/field/generator/asm/amd64/element_vec.go +++ b/field/generator/asm/amd64/element_vec.go @@ -349,7 +349,6 @@ func (f *FFAmd64) generateSumVec() { f.ADDQ(amd64.DX, addrA) f.Push(®isters, nMod4) // we don't need tmp0 - nMod4 = "" f.LABEL(loop) f.TESTQ(n, n) @@ -373,8 +372,6 @@ func (f *FFAmd64) generateSumVec() { f.JMP(loop) f.Push(®isters, n, addrA) // we don't need len - n = "" - addrA = "" f.LABEL(accumulate) @@ -471,15 +468,8 @@ func (f *FFAmd64) generateSumVec() { f.LabelRegisters("r", r...) 
// we don't need w0h, w1h, w2h anymore f.Push(®isters, w0h, w1h, w2h) - w0h = "" - w1h = "" - w2h = "" // we don't need the low bits anymore f.Push(®isters, low0h, low1h, low2h, low3h) - low0h = "" - low1h = "" - low2h = "" - low3h = "" // Reduce using single-word Barrett mu := f.Pop(®isters) @@ -513,7 +503,7 @@ func (f *FFAmd64) generateSumVec() { f.Mov(r, addrRes) // sub modulus - f.Comment("TODO @gbotrel check if 2 conditional substracts is guaranteed to be suffisant for mod reduce") + f.Comment("TODO @gbotrel check if 2 conditional subtracts is guaranteed to be suffisant for mod reduce") f.SUBQ(f.qAt(0), r0) f.SBBQ(f.qAt(1), r1) f.SBBQ(f.qAt(2), r2) From f45aeb99036caebbff80d82119b7b1035214f6b6 Mon Sep 17 00:00:00 2001 From: Gautam Botrel Date: Sun, 22 Sep 2024 11:22:06 -0500 Subject: [PATCH 08/14] fix: update bound for vec sum to match parameter choices --- ecc/bls12-377/fr/element_ops_amd64.go | 2 +- ecc/bls12-377/fr/element_ops_amd64.s | 3 +++ ecc/bls12-381/fr/element_ops_amd64.go | 2 +- ecc/bls12-381/fr/element_ops_amd64.s | 3 +++ ecc/bls24-315/fr/element_ops_amd64.go | 2 +- ecc/bls24-315/fr/element_ops_amd64.s | 3 +++ ecc/bls24-317/fr/element_ops_amd64.go | 2 +- ecc/bls24-317/fr/element_ops_amd64.s | 3 +++ ecc/bn254/fp/element_ops_amd64.go | 2 +- ecc/bn254/fp/element_ops_amd64.s | 3 +++ ecc/bn254/fr/element_ops_amd64.go | 2 +- ecc/bn254/fr/element_ops_amd64.s | 3 +++ ecc/stark-curve/fp/element_ops_amd64.go | 2 +- ecc/stark-curve/fp/element_ops_amd64.s | 3 +++ ecc/stark-curve/fr/element_ops_amd64.go | 2 +- ecc/stark-curve/fr/element_ops_amd64.s | 3 +++ field/generator/asm/amd64/element_vec.go | 4 ++++ field/generator/internal/templates/element/ops_asm.go | 2 +- 18 files changed, 37 insertions(+), 9 deletions(-) diff --git a/ecc/bls12-377/fr/element_ops_amd64.go b/ecc/bls12-377/fr/element_ops_amd64.go index fe4606378..49da23450 100644 --- a/ecc/bls12-377/fr/element_ops_amd64.go +++ b/ecc/bls12-377/fr/element_ops_amd64.go @@ -85,7 +85,7 @@ func scalarMulVec(res, a, b *Element, n uint64) func (vector *Vector) Sum() (res Element) { n := uint64(len(*vector)) const minN = 16 * 7 // AVX512 slower than generic for small n - const maxN = (1 << 32) + 1 + const maxN = (1 << 32) - 1 if !supportAvx512 || n <= minN || n >= maxN { // call sumVecGeneric sumVecGeneric(&res, *vector) diff --git a/ecc/bls12-377/fr/element_ops_amd64.s b/ecc/bls12-377/fr/element_ops_amd64.s index 29acbbfc6..e8892c821 100644 --- a/ecc/bls12-377/fr/element_ops_amd64.s +++ b/ecc/bls12-377/fr/element_ops_amd64.s @@ -654,6 +654,9 @@ TEXT ·sumVec(SB), NOSPLIT, $0-24 // r3 = carry + hi(w2h) + w3l + lo(w3h) // r4 = carry + hi(w3h) // we then reduce the sum using a single-word Barrett reduction + // we pick mu = 2^288 / q; which correspond to 4.5 words max. + // meaning we must guarantee that r4 fits in 32bits. 
+ // To do so, we reduce N to 2^32-1 (since r4 receives 2 carries max) MOVQ a+8(FP), R14 MOVQ n+16(FP), R15 diff --git a/ecc/bls12-381/fr/element_ops_amd64.go b/ecc/bls12-381/fr/element_ops_amd64.go index fe4606378..49da23450 100644 --- a/ecc/bls12-381/fr/element_ops_amd64.go +++ b/ecc/bls12-381/fr/element_ops_amd64.go @@ -85,7 +85,7 @@ func scalarMulVec(res, a, b *Element, n uint64) func (vector *Vector) Sum() (res Element) { n := uint64(len(*vector)) const minN = 16 * 7 // AVX512 slower than generic for small n - const maxN = (1 << 32) + 1 + const maxN = (1 << 32) - 1 if !supportAvx512 || n <= minN || n >= maxN { // call sumVecGeneric sumVecGeneric(&res, *vector) diff --git a/ecc/bls12-381/fr/element_ops_amd64.s b/ecc/bls12-381/fr/element_ops_amd64.s index fc3206512..c979c96f6 100644 --- a/ecc/bls12-381/fr/element_ops_amd64.s +++ b/ecc/bls12-381/fr/element_ops_amd64.s @@ -654,6 +654,9 @@ TEXT ·sumVec(SB), NOSPLIT, $0-24 // r3 = carry + hi(w2h) + w3l + lo(w3h) // r4 = carry + hi(w3h) // we then reduce the sum using a single-word Barrett reduction + // we pick mu = 2^288 / q; which correspond to 4.5 words max. + // meaning we must guarantee that r4 fits in 32bits. + // To do so, we reduce N to 2^32-1 (since r4 receives 2 carries max) MOVQ a+8(FP), R14 MOVQ n+16(FP), R15 diff --git a/ecc/bls24-315/fr/element_ops_amd64.go b/ecc/bls24-315/fr/element_ops_amd64.go index fe4606378..49da23450 100644 --- a/ecc/bls24-315/fr/element_ops_amd64.go +++ b/ecc/bls24-315/fr/element_ops_amd64.go @@ -85,7 +85,7 @@ func scalarMulVec(res, a, b *Element, n uint64) func (vector *Vector) Sum() (res Element) { n := uint64(len(*vector)) const minN = 16 * 7 // AVX512 slower than generic for small n - const maxN = (1 << 32) + 1 + const maxN = (1 << 32) - 1 if !supportAvx512 || n <= minN || n >= maxN { // call sumVecGeneric sumVecGeneric(&res, *vector) diff --git a/ecc/bls24-315/fr/element_ops_amd64.s b/ecc/bls24-315/fr/element_ops_amd64.s index 74ecdfcfd..e2e5948d9 100644 --- a/ecc/bls24-315/fr/element_ops_amd64.s +++ b/ecc/bls24-315/fr/element_ops_amd64.s @@ -654,6 +654,9 @@ TEXT ·sumVec(SB), NOSPLIT, $0-24 // r3 = carry + hi(w2h) + w3l + lo(w3h) // r4 = carry + hi(w3h) // we then reduce the sum using a single-word Barrett reduction + // we pick mu = 2^288 / q; which correspond to 4.5 words max. + // meaning we must guarantee that r4 fits in 32bits. + // To do so, we reduce N to 2^32-1 (since r4 receives 2 carries max) MOVQ a+8(FP), R14 MOVQ n+16(FP), R15 diff --git a/ecc/bls24-317/fr/element_ops_amd64.go b/ecc/bls24-317/fr/element_ops_amd64.go index fe4606378..49da23450 100644 --- a/ecc/bls24-317/fr/element_ops_amd64.go +++ b/ecc/bls24-317/fr/element_ops_amd64.go @@ -85,7 +85,7 @@ func scalarMulVec(res, a, b *Element, n uint64) func (vector *Vector) Sum() (res Element) { n := uint64(len(*vector)) const minN = 16 * 7 // AVX512 slower than generic for small n - const maxN = (1 << 32) + 1 + const maxN = (1 << 32) - 1 if !supportAvx512 || n <= minN || n >= maxN { // call sumVecGeneric sumVecGeneric(&res, *vector) diff --git a/ecc/bls24-317/fr/element_ops_amd64.s b/ecc/bls24-317/fr/element_ops_amd64.s index e9c2b90c3..8fe24fb1f 100644 --- a/ecc/bls24-317/fr/element_ops_amd64.s +++ b/ecc/bls24-317/fr/element_ops_amd64.s @@ -654,6 +654,9 @@ TEXT ·sumVec(SB), NOSPLIT, $0-24 // r3 = carry + hi(w2h) + w3l + lo(w3h) // r4 = carry + hi(w3h) // we then reduce the sum using a single-word Barrett reduction + // we pick mu = 2^288 / q; which correspond to 4.5 words max. + // meaning we must guarantee that r4 fits in 32bits. 
+ // To do so, we reduce N to 2^32-1 (since r4 receives 2 carries max) MOVQ a+8(FP), R14 MOVQ n+16(FP), R15 diff --git a/ecc/bn254/fp/element_ops_amd64.go b/ecc/bn254/fp/element_ops_amd64.go index 3963ace9d..328d9c4ab 100644 --- a/ecc/bn254/fp/element_ops_amd64.go +++ b/ecc/bn254/fp/element_ops_amd64.go @@ -85,7 +85,7 @@ func scalarMulVec(res, a, b *Element, n uint64) func (vector *Vector) Sum() (res Element) { n := uint64(len(*vector)) const minN = 16 * 7 // AVX512 slower than generic for small n - const maxN = (1 << 32) + 1 + const maxN = (1 << 32) - 1 if !supportAvx512 || n <= minN || n >= maxN { // call sumVecGeneric sumVecGeneric(&res, *vector) diff --git a/ecc/bn254/fp/element_ops_amd64.s b/ecc/bn254/fp/element_ops_amd64.s index 838847743..4cf037b15 100644 --- a/ecc/bn254/fp/element_ops_amd64.s +++ b/ecc/bn254/fp/element_ops_amd64.s @@ -654,6 +654,9 @@ TEXT ·sumVec(SB), NOSPLIT, $0-24 // r3 = carry + hi(w2h) + w3l + lo(w3h) // r4 = carry + hi(w3h) // we then reduce the sum using a single-word Barrett reduction + // we pick mu = 2^288 / q; which correspond to 4.5 words max. + // meaning we must guarantee that r4 fits in 32bits. + // To do so, we reduce N to 2^32-1 (since r4 receives 2 carries max) MOVQ a+8(FP), R14 MOVQ n+16(FP), R15 diff --git a/ecc/bn254/fr/element_ops_amd64.go b/ecc/bn254/fr/element_ops_amd64.go index fe4606378..49da23450 100644 --- a/ecc/bn254/fr/element_ops_amd64.go +++ b/ecc/bn254/fr/element_ops_amd64.go @@ -85,7 +85,7 @@ func scalarMulVec(res, a, b *Element, n uint64) func (vector *Vector) Sum() (res Element) { n := uint64(len(*vector)) const minN = 16 * 7 // AVX512 slower than generic for small n - const maxN = (1 << 32) + 1 + const maxN = (1 << 32) - 1 if !supportAvx512 || n <= minN || n >= maxN { // call sumVecGeneric sumVecGeneric(&res, *vector) diff --git a/ecc/bn254/fr/element_ops_amd64.s b/ecc/bn254/fr/element_ops_amd64.s index 1ae6c2c48..652a82c9c 100644 --- a/ecc/bn254/fr/element_ops_amd64.s +++ b/ecc/bn254/fr/element_ops_amd64.s @@ -654,6 +654,9 @@ TEXT ·sumVec(SB), NOSPLIT, $0-24 // r3 = carry + hi(w2h) + w3l + lo(w3h) // r4 = carry + hi(w3h) // we then reduce the sum using a single-word Barrett reduction + // we pick mu = 2^288 / q; which correspond to 4.5 words max. + // meaning we must guarantee that r4 fits in 32bits. + // To do so, we reduce N to 2^32-1 (since r4 receives 2 carries max) MOVQ a+8(FP), R14 MOVQ n+16(FP), R15 diff --git a/ecc/stark-curve/fp/element_ops_amd64.go b/ecc/stark-curve/fp/element_ops_amd64.go index 3963ace9d..328d9c4ab 100644 --- a/ecc/stark-curve/fp/element_ops_amd64.go +++ b/ecc/stark-curve/fp/element_ops_amd64.go @@ -85,7 +85,7 @@ func scalarMulVec(res, a, b *Element, n uint64) func (vector *Vector) Sum() (res Element) { n := uint64(len(*vector)) const minN = 16 * 7 // AVX512 slower than generic for small n - const maxN = (1 << 32) + 1 + const maxN = (1 << 32) - 1 if !supportAvx512 || n <= minN || n >= maxN { // call sumVecGeneric sumVecGeneric(&res, *vector) diff --git a/ecc/stark-curve/fp/element_ops_amd64.s b/ecc/stark-curve/fp/element_ops_amd64.s index f9d6112d9..6415f5e4e 100644 --- a/ecc/stark-curve/fp/element_ops_amd64.s +++ b/ecc/stark-curve/fp/element_ops_amd64.s @@ -654,6 +654,9 @@ TEXT ·sumVec(SB), NOSPLIT, $0-24 // r3 = carry + hi(w2h) + w3l + lo(w3h) // r4 = carry + hi(w3h) // we then reduce the sum using a single-word Barrett reduction + // we pick mu = 2^288 / q; which correspond to 4.5 words max. + // meaning we must guarantee that r4 fits in 32bits. 
+ // To do so, we reduce N to 2^32-1 (since r4 receives 2 carries max) MOVQ a+8(FP), R14 MOVQ n+16(FP), R15 diff --git a/ecc/stark-curve/fr/element_ops_amd64.go b/ecc/stark-curve/fr/element_ops_amd64.go index fe4606378..49da23450 100644 --- a/ecc/stark-curve/fr/element_ops_amd64.go +++ b/ecc/stark-curve/fr/element_ops_amd64.go @@ -85,7 +85,7 @@ func scalarMulVec(res, a, b *Element, n uint64) func (vector *Vector) Sum() (res Element) { n := uint64(len(*vector)) const minN = 16 * 7 // AVX512 slower than generic for small n - const maxN = (1 << 32) + 1 + const maxN = (1 << 32) - 1 if !supportAvx512 || n <= minN || n >= maxN { // call sumVecGeneric sumVecGeneric(&res, *vector) diff --git a/ecc/stark-curve/fr/element_ops_amd64.s b/ecc/stark-curve/fr/element_ops_amd64.s index 02feed270..3d6711bfc 100644 --- a/ecc/stark-curve/fr/element_ops_amd64.s +++ b/ecc/stark-curve/fr/element_ops_amd64.s @@ -654,6 +654,9 @@ TEXT ·sumVec(SB), NOSPLIT, $0-24 // r3 = carry + hi(w2h) + w3l + lo(w3h) // r4 = carry + hi(w3h) // we then reduce the sum using a single-word Barrett reduction + // we pick mu = 2^288 / q; which correspond to 4.5 words max. + // meaning we must guarantee that r4 fits in 32bits. + // To do so, we reduce N to 2^32-1 (since r4 receives 2 carries max) MOVQ a+8(FP), R14 MOVQ n+16(FP), R15 diff --git a/field/generator/asm/amd64/element_vec.go b/field/generator/asm/amd64/element_vec.go index c6eee5e14..dc968c42e 100644 --- a/field/generator/asm/amd64/element_vec.go +++ b/field/generator/asm/amd64/element_vec.go @@ -275,6 +275,9 @@ func (f *FFAmd64) generateSumVec() { // r3 = carry + hi(w2h) + w3l + lo(w3h) // r4 = carry + hi(w3h) // we then reduce the sum using a single-word Barrett reduction + // we pick mu = 2^288 / q; which correspond to 4.5 words max. + // meaning we must guarantee that r4 fits in 32bits. + // To do so, we reduce N to 2^32-1 (since r4 receives 2 carries max) `) // registers & labels we need @@ -464,6 +467,7 @@ func (f *FFAmd64) generateSumVec() { r2 := w2l r3 := w3l r4 := w3h + r := []amd64.Register{r0, r1, r2, r3, r4} f.LabelRegisters("r", r...) 
// we don't need w0h, w1h, w2h anymore diff --git a/field/generator/internal/templates/element/ops_asm.go b/field/generator/internal/templates/element/ops_asm.go index 061e95d9a..c068d7125 100644 --- a/field/generator/internal/templates/element/ops_asm.go +++ b/field/generator/internal/templates/element/ops_asm.go @@ -70,7 +70,7 @@ func scalarMulVec(res, a, b *{{.ElementName}}, n uint64) func (vector *Vector) Sum() (res {{.ElementName}}) { n := uint64(len(*vector)) const minN = 16*7 // AVX512 slower than generic for small n - const maxN = (1 << 32) + 1 + const maxN = (1 << 32) - 1 if !supportAvx512 || n <= minN || n >= maxN { // call sumVecGeneric sumVecGeneric(&res, *vector) From 75120a095777641f3da8c299c2e405d1fcf01fcb Mon Sep 17 00:00:00 2001 From: Gautam Botrel Date: Mon, 23 Sep 2024 15:17:47 +0000 Subject: [PATCH 09/14] perf: loop 8 by 8, cosmetics --- ecc/bls12-377/fp/element_test.go | 2 +- ecc/bls12-377/fr/element_ops_amd64.s | 134 +++++++++--------- ecc/bls12-377/fr/element_test.go | 2 +- ecc/bls12-381/fp/element_test.go | 2 +- ecc/bls12-381/fr/element_ops_amd64.s | 134 +++++++++--------- ecc/bls12-381/fr/element_test.go | 2 +- ecc/bls24-315/fp/element_test.go | 2 +- ecc/bls24-315/fr/element_ops_amd64.s | 134 +++++++++--------- ecc/bls24-315/fr/element_test.go | 2 +- ecc/bls24-317/fp/element_test.go | 2 +- ecc/bls24-317/fr/element_ops_amd64.s | 134 +++++++++--------- ecc/bls24-317/fr/element_test.go | 2 +- ecc/bn254/fp/element_ops_amd64.s | 134 +++++++++--------- ecc/bn254/fp/element_test.go | 2 +- ecc/bn254/fr/element_ops_amd64.s | 134 +++++++++--------- ecc/bn254/fr/element_test.go | 2 +- ecc/bw6-633/fp/element_test.go | 2 +- ecc/bw6-633/fr/element_test.go | 2 +- ecc/bw6-761/fp/element_test.go | 2 +- ecc/bw6-761/fr/element_test.go | 2 +- ecc/secp256k1/fp/element_test.go | 2 +- ecc/secp256k1/fr/element_test.go | 2 +- ecc/stark-curve/fp/element_ops_amd64.s | 134 +++++++++--------- ecc/stark-curve/fp/element_test.go | 2 +- ecc/stark-curve/fr/element_ops_amd64.s | 134 +++++++++--------- ecc/stark-curve/fr/element_test.go | 2 +- field/generator/asm/amd64/element_vec.go | 125 ++++++++-------- .../internal/templates/element/tests.go | 2 +- field/goldilocks/element_test.go | 2 +- 29 files changed, 600 insertions(+), 637 deletions(-) diff --git a/ecc/bls12-377/fp/element_test.go b/ecc/bls12-377/fp/element_test.go index a9e630152..8376c41e2 100644 --- a/ecc/bls12-377/fp/element_test.go +++ b/ecc/bls12-377/fp/element_test.go @@ -714,7 +714,7 @@ func TestElementLexicographicallyLargest(t *testing.T) { func TestElementVecOps(t *testing.T) { assert := require.New(t) - const N = 1024*16 + 3 + const N = 1024*16 + 4 a := make(Vector, N) b := make(Vector, N) c := make(Vector, N) diff --git a/ecc/bls12-377/fr/element_ops_amd64.s b/ecc/bls12-377/fr/element_ops_amd64.s index e8892c821..49c4e31aa 100644 --- a/ecc/bls12-377/fr/element_ops_amd64.s +++ b/ecc/bls12-377/fr/element_ops_amd64.s @@ -633,9 +633,9 @@ noAdx_5: TEXT ·sumVec(SB), NOSPLIT, $0-24 // Derived from https://github.com/a16z/vectorized-fields - // The idea is to use Z registers to accumulate the sum of elements, 4 by 4 - // first, we handle the case where n % 4 != 0 and add to the accumulators the 1, 2 or 3 remaining elements - // then, we loop over the elements 4 by 4 and accumulate the sum in the Z registers + // The idea is to use Z registers to accumulate the sum of elements, 8 by 8 + // first, we handle the case where n % 8 != 0 + // then, we loop over the elements 8 by 8 and accumulate the sum in the Z registers // finally, we 
reduce the sum and store it in res // // when we move an element of a into a Z register, we use VPMOVZXDQ @@ -661,63 +661,66 @@ TEXT ·sumVec(SB), NOSPLIT, $0-24 MOVQ a+8(FP), R14 MOVQ n+16(FP), R15 - // initialize accumulators Z0, Z1, Z2, Z3 + // initialize accumulators Z0, Z1, Z2, Z3, Z4, Z5, Z6, Z7 VXORPS Z0, Z0, Z0 VMOVDQA64 Z0, Z1 VMOVDQA64 Z0, Z2 VMOVDQA64 Z0, Z3 + VMOVDQA64 Z0, Z4 + VMOVDQA64 Z0, Z5 + VMOVDQA64 Z0, Z6 + VMOVDQA64 Z0, Z7 - // n % 4 -> CX - // n / 4 -> R15 + // n % 8 -> CX + // n / 8 -> R15 MOVQ R15, CX - ANDQ $3, CX - SHRQ $2, R15 - CMPQ CX, $1 - JEQ r1_10 // we have 1 remaining element - CMPQ CX, $2 - JEQ r2_11 // we have 2 remaining elements - CMPQ CX, $3 - JNE loop4by4_8 // we have 0 remaining elements - - // we have 3 remaining elements - VPMOVZXDQ 2*32(R14), Z4 - VPADDQ Z4, Z0, Z0 - -r2_11: - // we have 2 remaining elements - VPMOVZXDQ 1*32(R14), Z4 - VPADDQ Z4, Z1, Z1 - -r1_10: - // we have 1 remaining element - VPMOVZXDQ 0*32(R14), Z4 - VPADDQ Z4, Z2, Z2 - MOVQ $32, DX - IMULQ CX, DX - ADDQ DX, R14 - -loop4by4_8: - TESTQ R15, R15 - JEQ accumulate_12 // n == 0, we are going to accumulate - VPMOVZXDQ 0*32(R14), Z4 - VPADDQ Z4, Z0, Z0 - VPMOVZXDQ 1*32(R14), Z4 - VPADDQ Z4, Z1, Z1 - VPMOVZXDQ 2*32(R14), Z4 - VPADDQ Z4, Z2, Z2 - VPMOVZXDQ 3*32(R14), Z4 - VPADDQ Z4, Z3, Z3 - - // increment pointers to visit next 4 elements - ADDQ $128, R14 + ANDQ $7, CX + SHRQ $3, R15 + +loop_single_10: + TESTQ CX, CX + JEQ loop8by8_8 // n % 8 == 0, we are going to loop over 8 by 8 + VPMOVZXDQ 0(R14), Z8 + VPADDQ Z8, Z0, Z0 + ADDQ $32, R14 + DECQ CX // decrement nMod8 + JMP loop_single_10 + +loop8by8_8: + TESTQ R15, R15 + JEQ accumulate_11 // n == 0, we are going to accumulate + VPMOVZXDQ 0*32(R14), Z8 + VPMOVZXDQ 1*32(R14), Z9 + VPMOVZXDQ 2*32(R14), Z10 + VPMOVZXDQ 3*32(R14), Z11 + VPMOVZXDQ 4*32(R14), Z12 + VPMOVZXDQ 5*32(R14), Z13 + VPMOVZXDQ 6*32(R14), Z14 + VPMOVZXDQ 7*32(R14), Z15 + PREFETCHT0 256(R14) + VPADDQ Z8, Z0, Z0 + VPADDQ Z9, Z1, Z1 + VPADDQ Z10, Z2, Z2 + VPADDQ Z11, Z3, Z3 + VPADDQ Z12, Z4, Z4 + VPADDQ Z13, Z5, Z5 + VPADDQ Z14, Z6, Z6 + VPADDQ Z15, Z7, Z7 + + // increment pointers to visit next 8 elements + ADDQ $256, R14 DECQ R15 // decrement n - JMP loop4by4_8 - -accumulate_12: - // accumulate the 4 Z registers into Z0 - VPADDQ Z1, Z0, Z0 + JMP loop8by8_8 + +accumulate_11: + // accumulate the 8 Z registers into Z0 + VPADDQ Z7, Z6, Z6 + VPADDQ Z6, Z5, Z5 + VPADDQ Z5, Z4, Z4 + VPADDQ Z4, Z3, Z3 VPADDQ Z3, Z2, Z2 - VPADDQ Z2, Z0, Z0 + VPADDQ Z2, Z1, Z1 + VPADDQ Z1, Z0, Z0 // carry propagation // lo(w0) -> BX @@ -748,23 +751,16 @@ accumulate_12: // lo(hi(w1)) -> CX // lo(hi(w2)) -> R15 // lo(hi(w3)) -> R14 - XORQ AX, AX // clear the flags - MOVQ SI, R13 - ANDQ $0xffffffff, R13 - SHLQ $32, R13 - SHRQ $32, SI - MOVQ R8, CX - ANDQ $0xffffffff, CX - SHLQ $32, CX - SHRQ $32, R8 - MOVQ R10, R15 - ANDQ $0xffffffff, R15 - SHLQ $32, R15 - SHRQ $32, R10 - MOVQ R12, R14 - ANDQ $0xffffffff, R14 - SHLQ $32, R14 - SHRQ $32, R12 +#define SPLIT_LO_HI(lo, hi) \ + MOVQ hi, lo; \ + ANDQ $0xffffffff, lo; \ + SHLQ $32, lo; \ + SHRQ $32, hi; \ + + SPLIT_LO_HI(R13, SI) + SPLIT_LO_HI(CX, R8) + SPLIT_LO_HI(R15, R10) + SPLIT_LO_HI(R14, R12) // r0 = w0l + lo(woh) // r1 = carry + hi(woh) + w1l + lo(w1h) diff --git a/ecc/bls12-377/fr/element_test.go b/ecc/bls12-377/fr/element_test.go index 265719d23..b576a09f4 100644 --- a/ecc/bls12-377/fr/element_test.go +++ b/ecc/bls12-377/fr/element_test.go @@ -710,7 +710,7 @@ func TestElementLexicographicallyLargest(t *testing.T) { func TestElementVecOps(t 
*testing.T) { assert := require.New(t) - const N = 1024*16 + 3 + const N = 1024*16 + 4 a := make(Vector, N) b := make(Vector, N) c := make(Vector, N) diff --git a/ecc/bls12-381/fp/element_test.go b/ecc/bls12-381/fp/element_test.go index 6e51a0915..88974f2cd 100644 --- a/ecc/bls12-381/fp/element_test.go +++ b/ecc/bls12-381/fp/element_test.go @@ -714,7 +714,7 @@ func TestElementLexicographicallyLargest(t *testing.T) { func TestElementVecOps(t *testing.T) { assert := require.New(t) - const N = 1024*16 + 3 + const N = 1024*16 + 4 a := make(Vector, N) b := make(Vector, N) c := make(Vector, N) diff --git a/ecc/bls12-381/fr/element_ops_amd64.s b/ecc/bls12-381/fr/element_ops_amd64.s index c979c96f6..e58114b96 100644 --- a/ecc/bls12-381/fr/element_ops_amd64.s +++ b/ecc/bls12-381/fr/element_ops_amd64.s @@ -633,9 +633,9 @@ noAdx_5: TEXT ·sumVec(SB), NOSPLIT, $0-24 // Derived from https://github.com/a16z/vectorized-fields - // The idea is to use Z registers to accumulate the sum of elements, 4 by 4 - // first, we handle the case where n % 4 != 0 and add to the accumulators the 1, 2 or 3 remaining elements - // then, we loop over the elements 4 by 4 and accumulate the sum in the Z registers + // The idea is to use Z registers to accumulate the sum of elements, 8 by 8 + // first, we handle the case where n % 8 != 0 + // then, we loop over the elements 8 by 8 and accumulate the sum in the Z registers // finally, we reduce the sum and store it in res // // when we move an element of a into a Z register, we use VPMOVZXDQ @@ -661,63 +661,66 @@ TEXT ·sumVec(SB), NOSPLIT, $0-24 MOVQ a+8(FP), R14 MOVQ n+16(FP), R15 - // initialize accumulators Z0, Z1, Z2, Z3 + // initialize accumulators Z0, Z1, Z2, Z3, Z4, Z5, Z6, Z7 VXORPS Z0, Z0, Z0 VMOVDQA64 Z0, Z1 VMOVDQA64 Z0, Z2 VMOVDQA64 Z0, Z3 + VMOVDQA64 Z0, Z4 + VMOVDQA64 Z0, Z5 + VMOVDQA64 Z0, Z6 + VMOVDQA64 Z0, Z7 - // n % 4 -> CX - // n / 4 -> R15 + // n % 8 -> CX + // n / 8 -> R15 MOVQ R15, CX - ANDQ $3, CX - SHRQ $2, R15 - CMPQ CX, $1 - JEQ r1_10 // we have 1 remaining element - CMPQ CX, $2 - JEQ r2_11 // we have 2 remaining elements - CMPQ CX, $3 - JNE loop4by4_8 // we have 0 remaining elements - - // we have 3 remaining elements - VPMOVZXDQ 2*32(R14), Z4 - VPADDQ Z4, Z0, Z0 - -r2_11: - // we have 2 remaining elements - VPMOVZXDQ 1*32(R14), Z4 - VPADDQ Z4, Z1, Z1 - -r1_10: - // we have 1 remaining element - VPMOVZXDQ 0*32(R14), Z4 - VPADDQ Z4, Z2, Z2 - MOVQ $32, DX - IMULQ CX, DX - ADDQ DX, R14 - -loop4by4_8: - TESTQ R15, R15 - JEQ accumulate_12 // n == 0, we are going to accumulate - VPMOVZXDQ 0*32(R14), Z4 - VPADDQ Z4, Z0, Z0 - VPMOVZXDQ 1*32(R14), Z4 - VPADDQ Z4, Z1, Z1 - VPMOVZXDQ 2*32(R14), Z4 - VPADDQ Z4, Z2, Z2 - VPMOVZXDQ 3*32(R14), Z4 - VPADDQ Z4, Z3, Z3 - - // increment pointers to visit next 4 elements - ADDQ $128, R14 + ANDQ $7, CX + SHRQ $3, R15 + +loop_single_10: + TESTQ CX, CX + JEQ loop8by8_8 // n % 8 == 0, we are going to loop over 8 by 8 + VPMOVZXDQ 0(R14), Z8 + VPADDQ Z8, Z0, Z0 + ADDQ $32, R14 + DECQ CX // decrement nMod8 + JMP loop_single_10 + +loop8by8_8: + TESTQ R15, R15 + JEQ accumulate_11 // n == 0, we are going to accumulate + VPMOVZXDQ 0*32(R14), Z8 + VPMOVZXDQ 1*32(R14), Z9 + VPMOVZXDQ 2*32(R14), Z10 + VPMOVZXDQ 3*32(R14), Z11 + VPMOVZXDQ 4*32(R14), Z12 + VPMOVZXDQ 5*32(R14), Z13 + VPMOVZXDQ 6*32(R14), Z14 + VPMOVZXDQ 7*32(R14), Z15 + PREFETCHT0 256(R14) + VPADDQ Z8, Z0, Z0 + VPADDQ Z9, Z1, Z1 + VPADDQ Z10, Z2, Z2 + VPADDQ Z11, Z3, Z3 + VPADDQ Z12, Z4, Z4 + VPADDQ Z13, Z5, Z5 + VPADDQ Z14, Z6, Z6 + VPADDQ Z15, Z7, Z7 + + // increment 
pointers to visit next 8 elements + ADDQ $256, R14 DECQ R15 // decrement n - JMP loop4by4_8 - -accumulate_12: - // accumulate the 4 Z registers into Z0 - VPADDQ Z1, Z0, Z0 + JMP loop8by8_8 + +accumulate_11: + // accumulate the 8 Z registers into Z0 + VPADDQ Z7, Z6, Z6 + VPADDQ Z6, Z5, Z5 + VPADDQ Z5, Z4, Z4 + VPADDQ Z4, Z3, Z3 VPADDQ Z3, Z2, Z2 - VPADDQ Z2, Z0, Z0 + VPADDQ Z2, Z1, Z1 + VPADDQ Z1, Z0, Z0 // carry propagation // lo(w0) -> BX @@ -748,23 +751,16 @@ accumulate_12: // lo(hi(w1)) -> CX // lo(hi(w2)) -> R15 // lo(hi(w3)) -> R14 - XORQ AX, AX // clear the flags - MOVQ SI, R13 - ANDQ $0xffffffff, R13 - SHLQ $32, R13 - SHRQ $32, SI - MOVQ R8, CX - ANDQ $0xffffffff, CX - SHLQ $32, CX - SHRQ $32, R8 - MOVQ R10, R15 - ANDQ $0xffffffff, R15 - SHLQ $32, R15 - SHRQ $32, R10 - MOVQ R12, R14 - ANDQ $0xffffffff, R14 - SHLQ $32, R14 - SHRQ $32, R12 +#define SPLIT_LO_HI(lo, hi) \ + MOVQ hi, lo; \ + ANDQ $0xffffffff, lo; \ + SHLQ $32, lo; \ + SHRQ $32, hi; \ + + SPLIT_LO_HI(R13, SI) + SPLIT_LO_HI(CX, R8) + SPLIT_LO_HI(R15, R10) + SPLIT_LO_HI(R14, R12) // r0 = w0l + lo(woh) // r1 = carry + hi(woh) + w1l + lo(w1h) diff --git a/ecc/bls12-381/fr/element_test.go b/ecc/bls12-381/fr/element_test.go index 2005553e1..ff0b3c188 100644 --- a/ecc/bls12-381/fr/element_test.go +++ b/ecc/bls12-381/fr/element_test.go @@ -710,7 +710,7 @@ func TestElementLexicographicallyLargest(t *testing.T) { func TestElementVecOps(t *testing.T) { assert := require.New(t) - const N = 1024*16 + 3 + const N = 1024*16 + 4 a := make(Vector, N) b := make(Vector, N) c := make(Vector, N) diff --git a/ecc/bls24-315/fp/element_test.go b/ecc/bls24-315/fp/element_test.go index 0e397c4c7..bdf5071ad 100644 --- a/ecc/bls24-315/fp/element_test.go +++ b/ecc/bls24-315/fp/element_test.go @@ -712,7 +712,7 @@ func TestElementLexicographicallyLargest(t *testing.T) { func TestElementVecOps(t *testing.T) { assert := require.New(t) - const N = 1024*16 + 3 + const N = 1024*16 + 4 a := make(Vector, N) b := make(Vector, N) c := make(Vector, N) diff --git a/ecc/bls24-315/fr/element_ops_amd64.s b/ecc/bls24-315/fr/element_ops_amd64.s index e2e5948d9..0061bbefe 100644 --- a/ecc/bls24-315/fr/element_ops_amd64.s +++ b/ecc/bls24-315/fr/element_ops_amd64.s @@ -633,9 +633,9 @@ noAdx_5: TEXT ·sumVec(SB), NOSPLIT, $0-24 // Derived from https://github.com/a16z/vectorized-fields - // The idea is to use Z registers to accumulate the sum of elements, 4 by 4 - // first, we handle the case where n % 4 != 0 and add to the accumulators the 1, 2 or 3 remaining elements - // then, we loop over the elements 4 by 4 and accumulate the sum in the Z registers + // The idea is to use Z registers to accumulate the sum of elements, 8 by 8 + // first, we handle the case where n % 8 != 0 + // then, we loop over the elements 8 by 8 and accumulate the sum in the Z registers // finally, we reduce the sum and store it in res // // when we move an element of a into a Z register, we use VPMOVZXDQ @@ -661,63 +661,66 @@ TEXT ·sumVec(SB), NOSPLIT, $0-24 MOVQ a+8(FP), R14 MOVQ n+16(FP), R15 - // initialize accumulators Z0, Z1, Z2, Z3 + // initialize accumulators Z0, Z1, Z2, Z3, Z4, Z5, Z6, Z7 VXORPS Z0, Z0, Z0 VMOVDQA64 Z0, Z1 VMOVDQA64 Z0, Z2 VMOVDQA64 Z0, Z3 + VMOVDQA64 Z0, Z4 + VMOVDQA64 Z0, Z5 + VMOVDQA64 Z0, Z6 + VMOVDQA64 Z0, Z7 - // n % 4 -> CX - // n / 4 -> R15 + // n % 8 -> CX + // n / 8 -> R15 MOVQ R15, CX - ANDQ $3, CX - SHRQ $2, R15 - CMPQ CX, $1 - JEQ r1_10 // we have 1 remaining element - CMPQ CX, $2 - JEQ r2_11 // we have 2 remaining elements - CMPQ CX, $3 - JNE loop4by4_8 // we 
have 0 remaining elements - - // we have 3 remaining elements - VPMOVZXDQ 2*32(R14), Z4 - VPADDQ Z4, Z0, Z0 - -r2_11: - // we have 2 remaining elements - VPMOVZXDQ 1*32(R14), Z4 - VPADDQ Z4, Z1, Z1 - -r1_10: - // we have 1 remaining element - VPMOVZXDQ 0*32(R14), Z4 - VPADDQ Z4, Z2, Z2 - MOVQ $32, DX - IMULQ CX, DX - ADDQ DX, R14 - -loop4by4_8: - TESTQ R15, R15 - JEQ accumulate_12 // n == 0, we are going to accumulate - VPMOVZXDQ 0*32(R14), Z4 - VPADDQ Z4, Z0, Z0 - VPMOVZXDQ 1*32(R14), Z4 - VPADDQ Z4, Z1, Z1 - VPMOVZXDQ 2*32(R14), Z4 - VPADDQ Z4, Z2, Z2 - VPMOVZXDQ 3*32(R14), Z4 - VPADDQ Z4, Z3, Z3 - - // increment pointers to visit next 4 elements - ADDQ $128, R14 + ANDQ $7, CX + SHRQ $3, R15 + +loop_single_10: + TESTQ CX, CX + JEQ loop8by8_8 // n % 8 == 0, we are going to loop over 8 by 8 + VPMOVZXDQ 0(R14), Z8 + VPADDQ Z8, Z0, Z0 + ADDQ $32, R14 + DECQ CX // decrement nMod8 + JMP loop_single_10 + +loop8by8_8: + TESTQ R15, R15 + JEQ accumulate_11 // n == 0, we are going to accumulate + VPMOVZXDQ 0*32(R14), Z8 + VPMOVZXDQ 1*32(R14), Z9 + VPMOVZXDQ 2*32(R14), Z10 + VPMOVZXDQ 3*32(R14), Z11 + VPMOVZXDQ 4*32(R14), Z12 + VPMOVZXDQ 5*32(R14), Z13 + VPMOVZXDQ 6*32(R14), Z14 + VPMOVZXDQ 7*32(R14), Z15 + PREFETCHT0 256(R14) + VPADDQ Z8, Z0, Z0 + VPADDQ Z9, Z1, Z1 + VPADDQ Z10, Z2, Z2 + VPADDQ Z11, Z3, Z3 + VPADDQ Z12, Z4, Z4 + VPADDQ Z13, Z5, Z5 + VPADDQ Z14, Z6, Z6 + VPADDQ Z15, Z7, Z7 + + // increment pointers to visit next 8 elements + ADDQ $256, R14 DECQ R15 // decrement n - JMP loop4by4_8 - -accumulate_12: - // accumulate the 4 Z registers into Z0 - VPADDQ Z1, Z0, Z0 + JMP loop8by8_8 + +accumulate_11: + // accumulate the 8 Z registers into Z0 + VPADDQ Z7, Z6, Z6 + VPADDQ Z6, Z5, Z5 + VPADDQ Z5, Z4, Z4 + VPADDQ Z4, Z3, Z3 VPADDQ Z3, Z2, Z2 - VPADDQ Z2, Z0, Z0 + VPADDQ Z2, Z1, Z1 + VPADDQ Z1, Z0, Z0 // carry propagation // lo(w0) -> BX @@ -748,23 +751,16 @@ accumulate_12: // lo(hi(w1)) -> CX // lo(hi(w2)) -> R15 // lo(hi(w3)) -> R14 - XORQ AX, AX // clear the flags - MOVQ SI, R13 - ANDQ $0xffffffff, R13 - SHLQ $32, R13 - SHRQ $32, SI - MOVQ R8, CX - ANDQ $0xffffffff, CX - SHLQ $32, CX - SHRQ $32, R8 - MOVQ R10, R15 - ANDQ $0xffffffff, R15 - SHLQ $32, R15 - SHRQ $32, R10 - MOVQ R12, R14 - ANDQ $0xffffffff, R14 - SHLQ $32, R14 - SHRQ $32, R12 +#define SPLIT_LO_HI(lo, hi) \ + MOVQ hi, lo; \ + ANDQ $0xffffffff, lo; \ + SHLQ $32, lo; \ + SHRQ $32, hi; \ + + SPLIT_LO_HI(R13, SI) + SPLIT_LO_HI(CX, R8) + SPLIT_LO_HI(R15, R10) + SPLIT_LO_HI(R14, R12) // r0 = w0l + lo(woh) // r1 = carry + hi(woh) + w1l + lo(w1h) diff --git a/ecc/bls24-315/fr/element_test.go b/ecc/bls24-315/fr/element_test.go index ca15b880d..007beea50 100644 --- a/ecc/bls24-315/fr/element_test.go +++ b/ecc/bls24-315/fr/element_test.go @@ -710,7 +710,7 @@ func TestElementLexicographicallyLargest(t *testing.T) { func TestElementVecOps(t *testing.T) { assert := require.New(t) - const N = 1024*16 + 3 + const N = 1024*16 + 4 a := make(Vector, N) b := make(Vector, N) c := make(Vector, N) diff --git a/ecc/bls24-317/fp/element_test.go b/ecc/bls24-317/fp/element_test.go index 1ef36f59f..9a20746bc 100644 --- a/ecc/bls24-317/fp/element_test.go +++ b/ecc/bls24-317/fp/element_test.go @@ -712,7 +712,7 @@ func TestElementLexicographicallyLargest(t *testing.T) { func TestElementVecOps(t *testing.T) { assert := require.New(t) - const N = 1024*16 + 3 + const N = 1024*16 + 4 a := make(Vector, N) b := make(Vector, N) c := make(Vector, N) diff --git a/ecc/bls24-317/fr/element_ops_amd64.s b/ecc/bls24-317/fr/element_ops_amd64.s index 8fe24fb1f..430db1b74 
100644 --- a/ecc/bls24-317/fr/element_ops_amd64.s +++ b/ecc/bls24-317/fr/element_ops_amd64.s @@ -633,9 +633,9 @@ noAdx_5: TEXT ·sumVec(SB), NOSPLIT, $0-24 // Derived from https://github.com/a16z/vectorized-fields - // The idea is to use Z registers to accumulate the sum of elements, 4 by 4 - // first, we handle the case where n % 4 != 0 and add to the accumulators the 1, 2 or 3 remaining elements - // then, we loop over the elements 4 by 4 and accumulate the sum in the Z registers + // The idea is to use Z registers to accumulate the sum of elements, 8 by 8 + // first, we handle the case where n % 8 != 0 + // then, we loop over the elements 8 by 8 and accumulate the sum in the Z registers // finally, we reduce the sum and store it in res // // when we move an element of a into a Z register, we use VPMOVZXDQ @@ -661,63 +661,66 @@ TEXT ·sumVec(SB), NOSPLIT, $0-24 MOVQ a+8(FP), R14 MOVQ n+16(FP), R15 - // initialize accumulators Z0, Z1, Z2, Z3 + // initialize accumulators Z0, Z1, Z2, Z3, Z4, Z5, Z6, Z7 VXORPS Z0, Z0, Z0 VMOVDQA64 Z0, Z1 VMOVDQA64 Z0, Z2 VMOVDQA64 Z0, Z3 + VMOVDQA64 Z0, Z4 + VMOVDQA64 Z0, Z5 + VMOVDQA64 Z0, Z6 + VMOVDQA64 Z0, Z7 - // n % 4 -> CX - // n / 4 -> R15 + // n % 8 -> CX + // n / 8 -> R15 MOVQ R15, CX - ANDQ $3, CX - SHRQ $2, R15 - CMPQ CX, $1 - JEQ r1_10 // we have 1 remaining element - CMPQ CX, $2 - JEQ r2_11 // we have 2 remaining elements - CMPQ CX, $3 - JNE loop4by4_8 // we have 0 remaining elements - - // we have 3 remaining elements - VPMOVZXDQ 2*32(R14), Z4 - VPADDQ Z4, Z0, Z0 - -r2_11: - // we have 2 remaining elements - VPMOVZXDQ 1*32(R14), Z4 - VPADDQ Z4, Z1, Z1 - -r1_10: - // we have 1 remaining element - VPMOVZXDQ 0*32(R14), Z4 - VPADDQ Z4, Z2, Z2 - MOVQ $32, DX - IMULQ CX, DX - ADDQ DX, R14 - -loop4by4_8: - TESTQ R15, R15 - JEQ accumulate_12 // n == 0, we are going to accumulate - VPMOVZXDQ 0*32(R14), Z4 - VPADDQ Z4, Z0, Z0 - VPMOVZXDQ 1*32(R14), Z4 - VPADDQ Z4, Z1, Z1 - VPMOVZXDQ 2*32(R14), Z4 - VPADDQ Z4, Z2, Z2 - VPMOVZXDQ 3*32(R14), Z4 - VPADDQ Z4, Z3, Z3 - - // increment pointers to visit next 4 elements - ADDQ $128, R14 + ANDQ $7, CX + SHRQ $3, R15 + +loop_single_10: + TESTQ CX, CX + JEQ loop8by8_8 // n % 8 == 0, we are going to loop over 8 by 8 + VPMOVZXDQ 0(R14), Z8 + VPADDQ Z8, Z0, Z0 + ADDQ $32, R14 + DECQ CX // decrement nMod8 + JMP loop_single_10 + +loop8by8_8: + TESTQ R15, R15 + JEQ accumulate_11 // n == 0, we are going to accumulate + VPMOVZXDQ 0*32(R14), Z8 + VPMOVZXDQ 1*32(R14), Z9 + VPMOVZXDQ 2*32(R14), Z10 + VPMOVZXDQ 3*32(R14), Z11 + VPMOVZXDQ 4*32(R14), Z12 + VPMOVZXDQ 5*32(R14), Z13 + VPMOVZXDQ 6*32(R14), Z14 + VPMOVZXDQ 7*32(R14), Z15 + PREFETCHT0 256(R14) + VPADDQ Z8, Z0, Z0 + VPADDQ Z9, Z1, Z1 + VPADDQ Z10, Z2, Z2 + VPADDQ Z11, Z3, Z3 + VPADDQ Z12, Z4, Z4 + VPADDQ Z13, Z5, Z5 + VPADDQ Z14, Z6, Z6 + VPADDQ Z15, Z7, Z7 + + // increment pointers to visit next 8 elements + ADDQ $256, R14 DECQ R15 // decrement n - JMP loop4by4_8 - -accumulate_12: - // accumulate the 4 Z registers into Z0 - VPADDQ Z1, Z0, Z0 + JMP loop8by8_8 + +accumulate_11: + // accumulate the 8 Z registers into Z0 + VPADDQ Z7, Z6, Z6 + VPADDQ Z6, Z5, Z5 + VPADDQ Z5, Z4, Z4 + VPADDQ Z4, Z3, Z3 VPADDQ Z3, Z2, Z2 - VPADDQ Z2, Z0, Z0 + VPADDQ Z2, Z1, Z1 + VPADDQ Z1, Z0, Z0 // carry propagation // lo(w0) -> BX @@ -748,23 +751,16 @@ accumulate_12: // lo(hi(w1)) -> CX // lo(hi(w2)) -> R15 // lo(hi(w3)) -> R14 - XORQ AX, AX // clear the flags - MOVQ SI, R13 - ANDQ $0xffffffff, R13 - SHLQ $32, R13 - SHRQ $32, SI - MOVQ R8, CX - ANDQ $0xffffffff, CX - SHLQ $32, CX - SHRQ $32, 
R8 - MOVQ R10, R15 - ANDQ $0xffffffff, R15 - SHLQ $32, R15 - SHRQ $32, R10 - MOVQ R12, R14 - ANDQ $0xffffffff, R14 - SHLQ $32, R14 - SHRQ $32, R12 +#define SPLIT_LO_HI(lo, hi) \ + MOVQ hi, lo; \ + ANDQ $0xffffffff, lo; \ + SHLQ $32, lo; \ + SHRQ $32, hi; \ + + SPLIT_LO_HI(R13, SI) + SPLIT_LO_HI(CX, R8) + SPLIT_LO_HI(R15, R10) + SPLIT_LO_HI(R14, R12) // r0 = w0l + lo(woh) // r1 = carry + hi(woh) + w1l + lo(w1h) diff --git a/ecc/bls24-317/fr/element_test.go b/ecc/bls24-317/fr/element_test.go index 7e887d602..881f55ff3 100644 --- a/ecc/bls24-317/fr/element_test.go +++ b/ecc/bls24-317/fr/element_test.go @@ -710,7 +710,7 @@ func TestElementLexicographicallyLargest(t *testing.T) { func TestElementVecOps(t *testing.T) { assert := require.New(t) - const N = 1024*16 + 3 + const N = 1024*16 + 4 a := make(Vector, N) b := make(Vector, N) c := make(Vector, N) diff --git a/ecc/bn254/fp/element_ops_amd64.s b/ecc/bn254/fp/element_ops_amd64.s index 4cf037b15..a841afce0 100644 --- a/ecc/bn254/fp/element_ops_amd64.s +++ b/ecc/bn254/fp/element_ops_amd64.s @@ -633,9 +633,9 @@ noAdx_5: TEXT ·sumVec(SB), NOSPLIT, $0-24 // Derived from https://github.com/a16z/vectorized-fields - // The idea is to use Z registers to accumulate the sum of elements, 4 by 4 - // first, we handle the case where n % 4 != 0 and add to the accumulators the 1, 2 or 3 remaining elements - // then, we loop over the elements 4 by 4 and accumulate the sum in the Z registers + // The idea is to use Z registers to accumulate the sum of elements, 8 by 8 + // first, we handle the case where n % 8 != 0 + // then, we loop over the elements 8 by 8 and accumulate the sum in the Z registers // finally, we reduce the sum and store it in res // // when we move an element of a into a Z register, we use VPMOVZXDQ @@ -661,63 +661,66 @@ TEXT ·sumVec(SB), NOSPLIT, $0-24 MOVQ a+8(FP), R14 MOVQ n+16(FP), R15 - // initialize accumulators Z0, Z1, Z2, Z3 + // initialize accumulators Z0, Z1, Z2, Z3, Z4, Z5, Z6, Z7 VXORPS Z0, Z0, Z0 VMOVDQA64 Z0, Z1 VMOVDQA64 Z0, Z2 VMOVDQA64 Z0, Z3 + VMOVDQA64 Z0, Z4 + VMOVDQA64 Z0, Z5 + VMOVDQA64 Z0, Z6 + VMOVDQA64 Z0, Z7 - // n % 4 -> CX - // n / 4 -> R15 + // n % 8 -> CX + // n / 8 -> R15 MOVQ R15, CX - ANDQ $3, CX - SHRQ $2, R15 - CMPQ CX, $1 - JEQ r1_10 // we have 1 remaining element - CMPQ CX, $2 - JEQ r2_11 // we have 2 remaining elements - CMPQ CX, $3 - JNE loop4by4_8 // we have 0 remaining elements - - // we have 3 remaining elements - VPMOVZXDQ 2*32(R14), Z4 - VPADDQ Z4, Z0, Z0 - -r2_11: - // we have 2 remaining elements - VPMOVZXDQ 1*32(R14), Z4 - VPADDQ Z4, Z1, Z1 - -r1_10: - // we have 1 remaining element - VPMOVZXDQ 0*32(R14), Z4 - VPADDQ Z4, Z2, Z2 - MOVQ $32, DX - IMULQ CX, DX - ADDQ DX, R14 - -loop4by4_8: - TESTQ R15, R15 - JEQ accumulate_12 // n == 0, we are going to accumulate - VPMOVZXDQ 0*32(R14), Z4 - VPADDQ Z4, Z0, Z0 - VPMOVZXDQ 1*32(R14), Z4 - VPADDQ Z4, Z1, Z1 - VPMOVZXDQ 2*32(R14), Z4 - VPADDQ Z4, Z2, Z2 - VPMOVZXDQ 3*32(R14), Z4 - VPADDQ Z4, Z3, Z3 - - // increment pointers to visit next 4 elements - ADDQ $128, R14 + ANDQ $7, CX + SHRQ $3, R15 + +loop_single_10: + TESTQ CX, CX + JEQ loop8by8_8 // n % 8 == 0, we are going to loop over 8 by 8 + VPMOVZXDQ 0(R14), Z8 + VPADDQ Z8, Z0, Z0 + ADDQ $32, R14 + DECQ CX // decrement nMod8 + JMP loop_single_10 + +loop8by8_8: + TESTQ R15, R15 + JEQ accumulate_11 // n == 0, we are going to accumulate + VPMOVZXDQ 0*32(R14), Z8 + VPMOVZXDQ 1*32(R14), Z9 + VPMOVZXDQ 2*32(R14), Z10 + VPMOVZXDQ 3*32(R14), Z11 + VPMOVZXDQ 4*32(R14), Z12 + VPMOVZXDQ 5*32(R14), Z13 + 
VPMOVZXDQ 6*32(R14), Z14 + VPMOVZXDQ 7*32(R14), Z15 + PREFETCHT0 256(R14) + VPADDQ Z8, Z0, Z0 + VPADDQ Z9, Z1, Z1 + VPADDQ Z10, Z2, Z2 + VPADDQ Z11, Z3, Z3 + VPADDQ Z12, Z4, Z4 + VPADDQ Z13, Z5, Z5 + VPADDQ Z14, Z6, Z6 + VPADDQ Z15, Z7, Z7 + + // increment pointers to visit next 8 elements + ADDQ $256, R14 DECQ R15 // decrement n - JMP loop4by4_8 - -accumulate_12: - // accumulate the 4 Z registers into Z0 - VPADDQ Z1, Z0, Z0 + JMP loop8by8_8 + +accumulate_11: + // accumulate the 8 Z registers into Z0 + VPADDQ Z7, Z6, Z6 + VPADDQ Z6, Z5, Z5 + VPADDQ Z5, Z4, Z4 + VPADDQ Z4, Z3, Z3 VPADDQ Z3, Z2, Z2 - VPADDQ Z2, Z0, Z0 + VPADDQ Z2, Z1, Z1 + VPADDQ Z1, Z0, Z0 // carry propagation // lo(w0) -> BX @@ -748,23 +751,16 @@ accumulate_12: // lo(hi(w1)) -> CX // lo(hi(w2)) -> R15 // lo(hi(w3)) -> R14 - XORQ AX, AX // clear the flags - MOVQ SI, R13 - ANDQ $0xffffffff, R13 - SHLQ $32, R13 - SHRQ $32, SI - MOVQ R8, CX - ANDQ $0xffffffff, CX - SHLQ $32, CX - SHRQ $32, R8 - MOVQ R10, R15 - ANDQ $0xffffffff, R15 - SHLQ $32, R15 - SHRQ $32, R10 - MOVQ R12, R14 - ANDQ $0xffffffff, R14 - SHLQ $32, R14 - SHRQ $32, R12 +#define SPLIT_LO_HI(lo, hi) \ + MOVQ hi, lo; \ + ANDQ $0xffffffff, lo; \ + SHLQ $32, lo; \ + SHRQ $32, hi; \ + + SPLIT_LO_HI(R13, SI) + SPLIT_LO_HI(CX, R8) + SPLIT_LO_HI(R15, R10) + SPLIT_LO_HI(R14, R12) // r0 = w0l + lo(woh) // r1 = carry + hi(woh) + w1l + lo(w1h) diff --git a/ecc/bn254/fp/element_test.go b/ecc/bn254/fp/element_test.go index e6250c4c8..14847a533 100644 --- a/ecc/bn254/fp/element_test.go +++ b/ecc/bn254/fp/element_test.go @@ -710,7 +710,7 @@ func TestElementLexicographicallyLargest(t *testing.T) { func TestElementVecOps(t *testing.T) { assert := require.New(t) - const N = 1024*16 + 3 + const N = 1024*16 + 4 a := make(Vector, N) b := make(Vector, N) c := make(Vector, N) diff --git a/ecc/bn254/fr/element_ops_amd64.s b/ecc/bn254/fr/element_ops_amd64.s index 652a82c9c..8e87e0dfa 100644 --- a/ecc/bn254/fr/element_ops_amd64.s +++ b/ecc/bn254/fr/element_ops_amd64.s @@ -633,9 +633,9 @@ noAdx_5: TEXT ·sumVec(SB), NOSPLIT, $0-24 // Derived from https://github.com/a16z/vectorized-fields - // The idea is to use Z registers to accumulate the sum of elements, 4 by 4 - // first, we handle the case where n % 4 != 0 and add to the accumulators the 1, 2 or 3 remaining elements - // then, we loop over the elements 4 by 4 and accumulate the sum in the Z registers + // The idea is to use Z registers to accumulate the sum of elements, 8 by 8 + // first, we handle the case where n % 8 != 0 + // then, we loop over the elements 8 by 8 and accumulate the sum in the Z registers // finally, we reduce the sum and store it in res // // when we move an element of a into a Z register, we use VPMOVZXDQ @@ -661,63 +661,66 @@ TEXT ·sumVec(SB), NOSPLIT, $0-24 MOVQ a+8(FP), R14 MOVQ n+16(FP), R15 - // initialize accumulators Z0, Z1, Z2, Z3 + // initialize accumulators Z0, Z1, Z2, Z3, Z4, Z5, Z6, Z7 VXORPS Z0, Z0, Z0 VMOVDQA64 Z0, Z1 VMOVDQA64 Z0, Z2 VMOVDQA64 Z0, Z3 + VMOVDQA64 Z0, Z4 + VMOVDQA64 Z0, Z5 + VMOVDQA64 Z0, Z6 + VMOVDQA64 Z0, Z7 - // n % 4 -> CX - // n / 4 -> R15 + // n % 8 -> CX + // n / 8 -> R15 MOVQ R15, CX - ANDQ $3, CX - SHRQ $2, R15 - CMPQ CX, $1 - JEQ r1_10 // we have 1 remaining element - CMPQ CX, $2 - JEQ r2_11 // we have 2 remaining elements - CMPQ CX, $3 - JNE loop4by4_8 // we have 0 remaining elements - - // we have 3 remaining elements - VPMOVZXDQ 2*32(R14), Z4 - VPADDQ Z4, Z0, Z0 - -r2_11: - // we have 2 remaining elements - VPMOVZXDQ 1*32(R14), Z4 - VPADDQ Z4, Z1, Z1 - -r1_10: - // we have 
1 remaining element - VPMOVZXDQ 0*32(R14), Z4 - VPADDQ Z4, Z2, Z2 - MOVQ $32, DX - IMULQ CX, DX - ADDQ DX, R14 - -loop4by4_8: - TESTQ R15, R15 - JEQ accumulate_12 // n == 0, we are going to accumulate - VPMOVZXDQ 0*32(R14), Z4 - VPADDQ Z4, Z0, Z0 - VPMOVZXDQ 1*32(R14), Z4 - VPADDQ Z4, Z1, Z1 - VPMOVZXDQ 2*32(R14), Z4 - VPADDQ Z4, Z2, Z2 - VPMOVZXDQ 3*32(R14), Z4 - VPADDQ Z4, Z3, Z3 - - // increment pointers to visit next 4 elements - ADDQ $128, R14 + ANDQ $7, CX + SHRQ $3, R15 + +loop_single_10: + TESTQ CX, CX + JEQ loop8by8_8 // n % 8 == 0, we are going to loop over 8 by 8 + VPMOVZXDQ 0(R14), Z8 + VPADDQ Z8, Z0, Z0 + ADDQ $32, R14 + DECQ CX // decrement nMod8 + JMP loop_single_10 + +loop8by8_8: + TESTQ R15, R15 + JEQ accumulate_11 // n == 0, we are going to accumulate + VPMOVZXDQ 0*32(R14), Z8 + VPMOVZXDQ 1*32(R14), Z9 + VPMOVZXDQ 2*32(R14), Z10 + VPMOVZXDQ 3*32(R14), Z11 + VPMOVZXDQ 4*32(R14), Z12 + VPMOVZXDQ 5*32(R14), Z13 + VPMOVZXDQ 6*32(R14), Z14 + VPMOVZXDQ 7*32(R14), Z15 + PREFETCHT0 256(R14) + VPADDQ Z8, Z0, Z0 + VPADDQ Z9, Z1, Z1 + VPADDQ Z10, Z2, Z2 + VPADDQ Z11, Z3, Z3 + VPADDQ Z12, Z4, Z4 + VPADDQ Z13, Z5, Z5 + VPADDQ Z14, Z6, Z6 + VPADDQ Z15, Z7, Z7 + + // increment pointers to visit next 8 elements + ADDQ $256, R14 DECQ R15 // decrement n - JMP loop4by4_8 - -accumulate_12: - // accumulate the 4 Z registers into Z0 - VPADDQ Z1, Z0, Z0 + JMP loop8by8_8 + +accumulate_11: + // accumulate the 8 Z registers into Z0 + VPADDQ Z7, Z6, Z6 + VPADDQ Z6, Z5, Z5 + VPADDQ Z5, Z4, Z4 + VPADDQ Z4, Z3, Z3 VPADDQ Z3, Z2, Z2 - VPADDQ Z2, Z0, Z0 + VPADDQ Z2, Z1, Z1 + VPADDQ Z1, Z0, Z0 // carry propagation // lo(w0) -> BX @@ -748,23 +751,16 @@ accumulate_12: // lo(hi(w1)) -> CX // lo(hi(w2)) -> R15 // lo(hi(w3)) -> R14 - XORQ AX, AX // clear the flags - MOVQ SI, R13 - ANDQ $0xffffffff, R13 - SHLQ $32, R13 - SHRQ $32, SI - MOVQ R8, CX - ANDQ $0xffffffff, CX - SHLQ $32, CX - SHRQ $32, R8 - MOVQ R10, R15 - ANDQ $0xffffffff, R15 - SHLQ $32, R15 - SHRQ $32, R10 - MOVQ R12, R14 - ANDQ $0xffffffff, R14 - SHLQ $32, R14 - SHRQ $32, R12 +#define SPLIT_LO_HI(lo, hi) \ + MOVQ hi, lo; \ + ANDQ $0xffffffff, lo; \ + SHLQ $32, lo; \ + SHRQ $32, hi; \ + + SPLIT_LO_HI(R13, SI) + SPLIT_LO_HI(CX, R8) + SPLIT_LO_HI(R15, R10) + SPLIT_LO_HI(R14, R12) // r0 = w0l + lo(woh) // r1 = carry + hi(woh) + w1l + lo(w1h) diff --git a/ecc/bn254/fr/element_test.go b/ecc/bn254/fr/element_test.go index 2ace9113a..8c8e431dc 100644 --- a/ecc/bn254/fr/element_test.go +++ b/ecc/bn254/fr/element_test.go @@ -710,7 +710,7 @@ func TestElementLexicographicallyLargest(t *testing.T) { func TestElementVecOps(t *testing.T) { assert := require.New(t) - const N = 1024*16 + 3 + const N = 1024*16 + 4 a := make(Vector, N) b := make(Vector, N) c := make(Vector, N) diff --git a/ecc/bw6-633/fp/element_test.go b/ecc/bw6-633/fp/element_test.go index c332c558f..73b3146b6 100644 --- a/ecc/bw6-633/fp/element_test.go +++ b/ecc/bw6-633/fp/element_test.go @@ -722,7 +722,7 @@ func TestElementLexicographicallyLargest(t *testing.T) { func TestElementVecOps(t *testing.T) { assert := require.New(t) - const N = 1024*16 + 3 + const N = 1024*16 + 4 a := make(Vector, N) b := make(Vector, N) c := make(Vector, N) diff --git a/ecc/bw6-633/fr/element_test.go b/ecc/bw6-633/fr/element_test.go index d805a78a6..a4b82539b 100644 --- a/ecc/bw6-633/fr/element_test.go +++ b/ecc/bw6-633/fr/element_test.go @@ -712,7 +712,7 @@ func TestElementLexicographicallyLargest(t *testing.T) { func TestElementVecOps(t *testing.T) { assert := require.New(t) - const N = 1024*16 + 3 + const N = 
1024*16 + 4 a := make(Vector, N) b := make(Vector, N) c := make(Vector, N) diff --git a/ecc/bw6-761/fp/element_test.go b/ecc/bw6-761/fp/element_test.go index e5cddf1ab..d1f404566 100644 --- a/ecc/bw6-761/fp/element_test.go +++ b/ecc/bw6-761/fp/element_test.go @@ -726,7 +726,7 @@ func TestElementLexicographicallyLargest(t *testing.T) { func TestElementVecOps(t *testing.T) { assert := require.New(t) - const N = 1024*16 + 3 + const N = 1024*16 + 4 a := make(Vector, N) b := make(Vector, N) c := make(Vector, N) diff --git a/ecc/bw6-761/fr/element_test.go b/ecc/bw6-761/fr/element_test.go index c483e9d15..90d162aee 100644 --- a/ecc/bw6-761/fr/element_test.go +++ b/ecc/bw6-761/fr/element_test.go @@ -714,7 +714,7 @@ func TestElementLexicographicallyLargest(t *testing.T) { func TestElementVecOps(t *testing.T) { assert := require.New(t) - const N = 1024*16 + 3 + const N = 1024*16 + 4 a := make(Vector, N) b := make(Vector, N) c := make(Vector, N) diff --git a/ecc/secp256k1/fp/element_test.go b/ecc/secp256k1/fp/element_test.go index 4a03b4c75..4cedc9a12 100644 --- a/ecc/secp256k1/fp/element_test.go +++ b/ecc/secp256k1/fp/element_test.go @@ -708,7 +708,7 @@ func TestElementLexicographicallyLargest(t *testing.T) { func TestElementVecOps(t *testing.T) { assert := require.New(t) - const N = 1024*16 + 3 + const N = 1024*16 + 4 a := make(Vector, N) b := make(Vector, N) c := make(Vector, N) diff --git a/ecc/secp256k1/fr/element_test.go b/ecc/secp256k1/fr/element_test.go index 3a92c4c76..f97f76998 100644 --- a/ecc/secp256k1/fr/element_test.go +++ b/ecc/secp256k1/fr/element_test.go @@ -708,7 +708,7 @@ func TestElementLexicographicallyLargest(t *testing.T) { func TestElementVecOps(t *testing.T) { assert := require.New(t) - const N = 1024*16 + 3 + const N = 1024*16 + 4 a := make(Vector, N) b := make(Vector, N) c := make(Vector, N) diff --git a/ecc/stark-curve/fp/element_ops_amd64.s b/ecc/stark-curve/fp/element_ops_amd64.s index 6415f5e4e..fa703f31d 100644 --- a/ecc/stark-curve/fp/element_ops_amd64.s +++ b/ecc/stark-curve/fp/element_ops_amd64.s @@ -633,9 +633,9 @@ noAdx_5: TEXT ·sumVec(SB), NOSPLIT, $0-24 // Derived from https://github.com/a16z/vectorized-fields - // The idea is to use Z registers to accumulate the sum of elements, 4 by 4 - // first, we handle the case where n % 4 != 0 and add to the accumulators the 1, 2 or 3 remaining elements - // then, we loop over the elements 4 by 4 and accumulate the sum in the Z registers + // The idea is to use Z registers to accumulate the sum of elements, 8 by 8 + // first, we handle the case where n % 8 != 0 + // then, we loop over the elements 8 by 8 and accumulate the sum in the Z registers // finally, we reduce the sum and store it in res // // when we move an element of a into a Z register, we use VPMOVZXDQ @@ -661,63 +661,66 @@ TEXT ·sumVec(SB), NOSPLIT, $0-24 MOVQ a+8(FP), R14 MOVQ n+16(FP), R15 - // initialize accumulators Z0, Z1, Z2, Z3 + // initialize accumulators Z0, Z1, Z2, Z3, Z4, Z5, Z6, Z7 VXORPS Z0, Z0, Z0 VMOVDQA64 Z0, Z1 VMOVDQA64 Z0, Z2 VMOVDQA64 Z0, Z3 + VMOVDQA64 Z0, Z4 + VMOVDQA64 Z0, Z5 + VMOVDQA64 Z0, Z6 + VMOVDQA64 Z0, Z7 - // n % 4 -> CX - // n / 4 -> R15 + // n % 8 -> CX + // n / 8 -> R15 MOVQ R15, CX - ANDQ $3, CX - SHRQ $2, R15 - CMPQ CX, $1 - JEQ r1_10 // we have 1 remaining element - CMPQ CX, $2 - JEQ r2_11 // we have 2 remaining elements - CMPQ CX, $3 - JNE loop4by4_8 // we have 0 remaining elements - - // we have 3 remaining elements - VPMOVZXDQ 2*32(R14), Z4 - VPADDQ Z4, Z0, Z0 - -r2_11: - // we have 2 remaining elements - VPMOVZXDQ 
1*32(R14), Z4 - VPADDQ Z4, Z1, Z1 - -r1_10: - // we have 1 remaining element - VPMOVZXDQ 0*32(R14), Z4 - VPADDQ Z4, Z2, Z2 - MOVQ $32, DX - IMULQ CX, DX - ADDQ DX, R14 - -loop4by4_8: - TESTQ R15, R15 - JEQ accumulate_12 // n == 0, we are going to accumulate - VPMOVZXDQ 0*32(R14), Z4 - VPADDQ Z4, Z0, Z0 - VPMOVZXDQ 1*32(R14), Z4 - VPADDQ Z4, Z1, Z1 - VPMOVZXDQ 2*32(R14), Z4 - VPADDQ Z4, Z2, Z2 - VPMOVZXDQ 3*32(R14), Z4 - VPADDQ Z4, Z3, Z3 - - // increment pointers to visit next 4 elements - ADDQ $128, R14 + ANDQ $7, CX + SHRQ $3, R15 + +loop_single_10: + TESTQ CX, CX + JEQ loop8by8_8 // n % 8 == 0, we are going to loop over 8 by 8 + VPMOVZXDQ 0(R14), Z8 + VPADDQ Z8, Z0, Z0 + ADDQ $32, R14 + DECQ CX // decrement nMod8 + JMP loop_single_10 + +loop8by8_8: + TESTQ R15, R15 + JEQ accumulate_11 // n == 0, we are going to accumulate + VPMOVZXDQ 0*32(R14), Z8 + VPMOVZXDQ 1*32(R14), Z9 + VPMOVZXDQ 2*32(R14), Z10 + VPMOVZXDQ 3*32(R14), Z11 + VPMOVZXDQ 4*32(R14), Z12 + VPMOVZXDQ 5*32(R14), Z13 + VPMOVZXDQ 6*32(R14), Z14 + VPMOVZXDQ 7*32(R14), Z15 + PREFETCHT0 256(R14) + VPADDQ Z8, Z0, Z0 + VPADDQ Z9, Z1, Z1 + VPADDQ Z10, Z2, Z2 + VPADDQ Z11, Z3, Z3 + VPADDQ Z12, Z4, Z4 + VPADDQ Z13, Z5, Z5 + VPADDQ Z14, Z6, Z6 + VPADDQ Z15, Z7, Z7 + + // increment pointers to visit next 8 elements + ADDQ $256, R14 DECQ R15 // decrement n - JMP loop4by4_8 - -accumulate_12: - // accumulate the 4 Z registers into Z0 - VPADDQ Z1, Z0, Z0 + JMP loop8by8_8 + +accumulate_11: + // accumulate the 8 Z registers into Z0 + VPADDQ Z7, Z6, Z6 + VPADDQ Z6, Z5, Z5 + VPADDQ Z5, Z4, Z4 + VPADDQ Z4, Z3, Z3 VPADDQ Z3, Z2, Z2 - VPADDQ Z2, Z0, Z0 + VPADDQ Z2, Z1, Z1 + VPADDQ Z1, Z0, Z0 // carry propagation // lo(w0) -> BX @@ -748,23 +751,16 @@ accumulate_12: // lo(hi(w1)) -> CX // lo(hi(w2)) -> R15 // lo(hi(w3)) -> R14 - XORQ AX, AX // clear the flags - MOVQ SI, R13 - ANDQ $0xffffffff, R13 - SHLQ $32, R13 - SHRQ $32, SI - MOVQ R8, CX - ANDQ $0xffffffff, CX - SHLQ $32, CX - SHRQ $32, R8 - MOVQ R10, R15 - ANDQ $0xffffffff, R15 - SHLQ $32, R15 - SHRQ $32, R10 - MOVQ R12, R14 - ANDQ $0xffffffff, R14 - SHLQ $32, R14 - SHRQ $32, R12 +#define SPLIT_LO_HI(lo, hi) \ + MOVQ hi, lo; \ + ANDQ $0xffffffff, lo; \ + SHLQ $32, lo; \ + SHRQ $32, hi; \ + + SPLIT_LO_HI(R13, SI) + SPLIT_LO_HI(CX, R8) + SPLIT_LO_HI(R15, R10) + SPLIT_LO_HI(R14, R12) // r0 = w0l + lo(woh) // r1 = carry + hi(woh) + w1l + lo(w1h) diff --git a/ecc/stark-curve/fp/element_test.go b/ecc/stark-curve/fp/element_test.go index 3f82bcd7c..c51d1708a 100644 --- a/ecc/stark-curve/fp/element_test.go +++ b/ecc/stark-curve/fp/element_test.go @@ -710,7 +710,7 @@ func TestElementLexicographicallyLargest(t *testing.T) { func TestElementVecOps(t *testing.T) { assert := require.New(t) - const N = 1024*16 + 3 + const N = 1024*16 + 4 a := make(Vector, N) b := make(Vector, N) c := make(Vector, N) diff --git a/ecc/stark-curve/fr/element_ops_amd64.s b/ecc/stark-curve/fr/element_ops_amd64.s index 3d6711bfc..1c21e37bd 100644 --- a/ecc/stark-curve/fr/element_ops_amd64.s +++ b/ecc/stark-curve/fr/element_ops_amd64.s @@ -633,9 +633,9 @@ noAdx_5: TEXT ·sumVec(SB), NOSPLIT, $0-24 // Derived from https://github.com/a16z/vectorized-fields - // The idea is to use Z registers to accumulate the sum of elements, 4 by 4 - // first, we handle the case where n % 4 != 0 and add to the accumulators the 1, 2 or 3 remaining elements - // then, we loop over the elements 4 by 4 and accumulate the sum in the Z registers + // The idea is to use Z registers to accumulate the sum of elements, 8 by 8 + // first, we handle the case 
where n % 8 != 0 + // then, we loop over the elements 8 by 8 and accumulate the sum in the Z registers // finally, we reduce the sum and store it in res // // when we move an element of a into a Z register, we use VPMOVZXDQ @@ -661,63 +661,66 @@ TEXT ·sumVec(SB), NOSPLIT, $0-24 MOVQ a+8(FP), R14 MOVQ n+16(FP), R15 - // initialize accumulators Z0, Z1, Z2, Z3 + // initialize accumulators Z0, Z1, Z2, Z3, Z4, Z5, Z6, Z7 VXORPS Z0, Z0, Z0 VMOVDQA64 Z0, Z1 VMOVDQA64 Z0, Z2 VMOVDQA64 Z0, Z3 + VMOVDQA64 Z0, Z4 + VMOVDQA64 Z0, Z5 + VMOVDQA64 Z0, Z6 + VMOVDQA64 Z0, Z7 - // n % 4 -> CX - // n / 4 -> R15 + // n % 8 -> CX + // n / 8 -> R15 MOVQ R15, CX - ANDQ $3, CX - SHRQ $2, R15 - CMPQ CX, $1 - JEQ r1_10 // we have 1 remaining element - CMPQ CX, $2 - JEQ r2_11 // we have 2 remaining elements - CMPQ CX, $3 - JNE loop4by4_8 // we have 0 remaining elements - - // we have 3 remaining elements - VPMOVZXDQ 2*32(R14), Z4 - VPADDQ Z4, Z0, Z0 - -r2_11: - // we have 2 remaining elements - VPMOVZXDQ 1*32(R14), Z4 - VPADDQ Z4, Z1, Z1 - -r1_10: - // we have 1 remaining element - VPMOVZXDQ 0*32(R14), Z4 - VPADDQ Z4, Z2, Z2 - MOVQ $32, DX - IMULQ CX, DX - ADDQ DX, R14 - -loop4by4_8: - TESTQ R15, R15 - JEQ accumulate_12 // n == 0, we are going to accumulate - VPMOVZXDQ 0*32(R14), Z4 - VPADDQ Z4, Z0, Z0 - VPMOVZXDQ 1*32(R14), Z4 - VPADDQ Z4, Z1, Z1 - VPMOVZXDQ 2*32(R14), Z4 - VPADDQ Z4, Z2, Z2 - VPMOVZXDQ 3*32(R14), Z4 - VPADDQ Z4, Z3, Z3 - - // increment pointers to visit next 4 elements - ADDQ $128, R14 + ANDQ $7, CX + SHRQ $3, R15 + +loop_single_10: + TESTQ CX, CX + JEQ loop8by8_8 // n % 8 == 0, we are going to loop over 8 by 8 + VPMOVZXDQ 0(R14), Z8 + VPADDQ Z8, Z0, Z0 + ADDQ $32, R14 + DECQ CX // decrement nMod8 + JMP loop_single_10 + +loop8by8_8: + TESTQ R15, R15 + JEQ accumulate_11 // n == 0, we are going to accumulate + VPMOVZXDQ 0*32(R14), Z8 + VPMOVZXDQ 1*32(R14), Z9 + VPMOVZXDQ 2*32(R14), Z10 + VPMOVZXDQ 3*32(R14), Z11 + VPMOVZXDQ 4*32(R14), Z12 + VPMOVZXDQ 5*32(R14), Z13 + VPMOVZXDQ 6*32(R14), Z14 + VPMOVZXDQ 7*32(R14), Z15 + PREFETCHT0 256(R14) + VPADDQ Z8, Z0, Z0 + VPADDQ Z9, Z1, Z1 + VPADDQ Z10, Z2, Z2 + VPADDQ Z11, Z3, Z3 + VPADDQ Z12, Z4, Z4 + VPADDQ Z13, Z5, Z5 + VPADDQ Z14, Z6, Z6 + VPADDQ Z15, Z7, Z7 + + // increment pointers to visit next 8 elements + ADDQ $256, R14 DECQ R15 // decrement n - JMP loop4by4_8 - -accumulate_12: - // accumulate the 4 Z registers into Z0 - VPADDQ Z1, Z0, Z0 + JMP loop8by8_8 + +accumulate_11: + // accumulate the 8 Z registers into Z0 + VPADDQ Z7, Z6, Z6 + VPADDQ Z6, Z5, Z5 + VPADDQ Z5, Z4, Z4 + VPADDQ Z4, Z3, Z3 VPADDQ Z3, Z2, Z2 - VPADDQ Z2, Z0, Z0 + VPADDQ Z2, Z1, Z1 + VPADDQ Z1, Z0, Z0 // carry propagation // lo(w0) -> BX @@ -748,23 +751,16 @@ accumulate_12: // lo(hi(w1)) -> CX // lo(hi(w2)) -> R15 // lo(hi(w3)) -> R14 - XORQ AX, AX // clear the flags - MOVQ SI, R13 - ANDQ $0xffffffff, R13 - SHLQ $32, R13 - SHRQ $32, SI - MOVQ R8, CX - ANDQ $0xffffffff, CX - SHLQ $32, CX - SHRQ $32, R8 - MOVQ R10, R15 - ANDQ $0xffffffff, R15 - SHLQ $32, R15 - SHRQ $32, R10 - MOVQ R12, R14 - ANDQ $0xffffffff, R14 - SHLQ $32, R14 - SHRQ $32, R12 +#define SPLIT_LO_HI(lo, hi) \ + MOVQ hi, lo; \ + ANDQ $0xffffffff, lo; \ + SHLQ $32, lo; \ + SHRQ $32, hi; \ + + SPLIT_LO_HI(R13, SI) + SPLIT_LO_HI(CX, R8) + SPLIT_LO_HI(R15, R10) + SPLIT_LO_HI(R14, R12) // r0 = w0l + lo(woh) // r1 = carry + hi(woh) + w1l + lo(w1h) diff --git a/ecc/stark-curve/fr/element_test.go b/ecc/stark-curve/fr/element_test.go index 94be0c490..a3478b0db 100644 --- a/ecc/stark-curve/fr/element_test.go +++ 
b/ecc/stark-curve/fr/element_test.go @@ -710,7 +710,7 @@ func TestElementLexicographicallyLargest(t *testing.T) { func TestElementVecOps(t *testing.T) { assert := require.New(t) - const N = 1024*16 + 3 + const N = 1024*16 + 4 a := make(Vector, N) b := make(Vector, N) c := make(Vector, N) diff --git a/field/generator/asm/amd64/element_vec.go b/field/generator/asm/amd64/element_vec.go index dc968c42e..07d7078a9 100644 --- a/field/generator/asm/amd64/element_vec.go +++ b/field/generator/asm/amd64/element_vec.go @@ -15,6 +15,8 @@ package amd64 import ( + "fmt" + "github.com/consensys/bavard/amd64" ) @@ -254,9 +256,9 @@ func (f *FFAmd64) generateSumVec() { f.WriteLn(` // Derived from https://github.com/a16z/vectorized-fields - // The idea is to use Z registers to accumulate the sum of elements, 4 by 4 - // first, we handle the case where n % 4 != 0 and add to the accumulators the 1, 2 or 3 remaining elements - // then, we loop over the elements 4 by 4 and accumulate the sum in the Z registers + // The idea is to use Z registers to accumulate the sum of elements, 8 by 8 + // first, we handle the case where n % 8 != 0 + // then, we loop over the elements 8 by 8 and accumulate the sum in the Z registers // finally, we reduce the sum and store it in res // // when we move an element of a into a Z register, we use VPMOVZXDQ @@ -283,12 +285,11 @@ func (f *FFAmd64) generateSumVec() { // registers & labels we need addrA := f.Pop(®isters) n := f.Pop(®isters) - nMod4 := f.Pop(®isters) + nMod8 := f.Pop(®isters) - loop := f.NewLabel("loop4by4") + loop := f.NewLabel("loop8by8") done := f.NewLabel("done") - rr1 := f.NewLabel("r1") - rr2 := f.NewLabel("r2") + loopSingle := f.NewLabel("loop_single") accumulate := f.NewLabel("accumulate") // AVX512 registers @@ -297,91 +298,81 @@ func (f *FFAmd64) generateSumVec() { Z2 := amd64.Register("Z2") Z3 := amd64.Register("Z3") Z4 := amd64.Register("Z4") + Z5 := amd64.Register("Z5") + Z6 := amd64.Register("Z6") + Z7 := amd64.Register("Z7") + Z8 := amd64.Register("Z8") + X0 := amd64.Register("X0") // load arguments f.MOVQ("a+8(FP)", addrA) f.MOVQ("n+16(FP)", n) - f.Comment("initialize accumulators Z0, Z1, Z2, Z3") + f.Comment("initialize accumulators Z0, Z1, Z2, Z3, Z4, Z5, Z6, Z7") f.VXORPS(Z0, Z0, Z0) f.VMOVDQA64(Z0, Z1) f.VMOVDQA64(Z0, Z2) f.VMOVDQA64(Z0, Z3) + f.VMOVDQA64(Z0, Z4) + f.VMOVDQA64(Z0, Z5) + f.VMOVDQA64(Z0, Z6) + f.VMOVDQA64(Z0, Z7) // note: we don't need to handle the case n==0; handled by caller already. 
// f.TESTQ(n, n) // f.JEQ(done, "n == 0, we are done") - f.LabelRegisters("n % 4", nMod4) - f.LabelRegisters("n / 4", n) - f.MOVQ(n, nMod4) - f.ANDQ("$3", nMod4) // t0 = n % 4 - f.SHRQ("$2", n) // len = n / 4 - - // if len % 4 != 0, we need to handle the remaining elements - f.CMPQ(nMod4, "$1") - f.JEQ(rr1, "we have 1 remaining element") - - f.CMPQ(nMod4, "$2") - f.JEQ(rr2, "we have 2 remaining elements") - - f.CMPQ(nMod4, "$3") - f.JNE(loop, "we have 0 remaining elements") + f.LabelRegisters("n % 8", nMod8) + f.LabelRegisters("n / 8", n) + f.MOVQ(n, nMod8) + f.ANDQ("$7", nMod8) // nMod8 = n % 8 + f.SHRQ("$3", n) // len = n / 8 - f.Comment("we have 3 remaining elements") - f.VPMOVZXDQ("2*32("+addrA+")", Z4) - f.VPADDQ(Z4, Z0, Z0) + f.LABEL(loopSingle) + f.TESTQ(nMod8, nMod8) + f.JEQ(loop, "n % 8 == 0, we are going to loop over 8 by 8") - f.LABEL(rr2) - f.Comment("we have 2 remaining elements") - // vpmovzxdq 1*32(PX), %zmm4; vpaddq %zmm4, %zmm1, %zmm1 - f.VPMOVZXDQ("1*32("+addrA+")", Z4) - f.VPADDQ(Z4, Z1, Z1) - - f.LABEL(rr1) - f.Comment("we have 1 remaining element") - // vpmovzxdq 0*32(PX), %zmm4; vpaddq %zmm4, %zmm2, %zmm2 - f.VPMOVZXDQ("0*32("+addrA+")", Z4) - f.VPADDQ(Z4, Z2, Z2) + f.VPMOVZXDQ("0("+addrA+")", Z8) + f.VPADDQ(Z8, Z0, Z0) + f.ADDQ("$32", addrA) - // mul $32 by tmp0 - // TODO use better instructions - f.MOVQ("$32", amd64.DX) - f.IMULQ(nMod4, amd64.DX) - f.ADDQ(amd64.DX, addrA) + f.DECQ(nMod8, "decrement nMod8") + f.JMP(loopSingle) - f.Push(®isters, nMod4) // we don't need tmp0 + f.Push(®isters, nMod8) // we don't need tmp0 f.LABEL(loop) f.TESTQ(n, n) f.JEQ(accumulate, "n == 0, we are going to accumulate") - f.VPMOVZXDQ("0*32("+addrA+")", Z4) - f.VPADDQ(Z4, Z0, Z0) - - f.VPMOVZXDQ("1*32("+addrA+")", Z4) - f.VPADDQ(Z4, Z1, Z1) - - f.VPMOVZXDQ("2*32("+addrA+")", Z4) - f.VPADDQ(Z4, Z2, Z2) - - f.VPMOVZXDQ("3*32("+addrA+")", Z4) - f.VPADDQ(Z4, Z3, Z3) + for i := 0; i < 8; i++ { + r := fmt.Sprintf("Z%d", i+8) + f.VPMOVZXDQ(fmt.Sprintf("%d*32("+string(addrA)+")", i), r) + } + f.WriteLn(fmt.Sprintf("PREFETCHT0 256(%[1]s)", addrA)) + for i := 0; i < 8; i++ { + r := fmt.Sprintf("Z%d", i) + f.VPADDQ(fmt.Sprintf("Z%d", i+8), r, r) + } - f.Comment("increment pointers to visit next 4 elements") - f.ADDQ("$128", addrA) + f.Comment("increment pointers to visit next 8 elements") + f.ADDQ("$256", addrA) f.DECQ(n, "decrement n") f.JMP(loop) - f.Push(®isters, n, addrA) // we don't need len + f.Push(®isters, n, addrA) f.LABEL(accumulate) - f.Comment("accumulate the 4 Z registers into Z0") - f.VPADDQ(Z1, Z0, Z0) + f.Comment("accumulate the 8 Z registers into Z0") + f.VPADDQ(Z7, Z6, Z6) + f.VPADDQ(Z6, Z5, Z5) + f.VPADDQ(Z5, Z4, Z4) + f.VPADDQ(Z4, Z3, Z3) f.VPADDQ(Z3, Z2, Z2) - f.VPADDQ(Z2, Z0, Z0) + f.VPADDQ(Z2, Z1, Z1) + f.VPADDQ(Z1, Z0, Z0) w0l := f.Pop(®isters) w0h := f.Pop(®isters) @@ -429,15 +420,19 @@ func (f *FFAmd64) generateSumVec() { f.LabelRegisters("lo(hi(w2))", low2h) f.LabelRegisters("lo(hi(w3))", low3h) - f.XORQ(amd64.AX, amd64.AX, "clear the flags") type hilo struct { hi, lo amd64.Register } + + f.WriteLn(`#define SPLIT_LO_HI(lo, hi) \ + MOVQ hi, lo; \ + ANDQ $0xffffffff, lo; \ + SHLQ $32, lo; \ + SHRQ $32, hi; \ + `) + for _, v := range []hilo{{w0h, low0h}, {w1h, low1h}, {w2h, low2h}, {w3h, low3h}} { - f.MOVQ(v.hi, v.lo) - f.ANDQ("$0xffffffff", v.lo) - f.SHLQ("$32", v.lo) - f.SHRQ("$32", v.hi) + f.WriteLn(`SPLIT_LO_HI(` + string(v.lo) + `, ` + string(v.hi) + `)`) } f.WriteLn(` diff --git a/field/generator/internal/templates/element/tests.go 
b/field/generator/internal/templates/element/tests.go index 85eafacaf..f07e4cd0b 100644 --- a/field/generator/internal/templates/element/tests.go +++ b/field/generator/internal/templates/element/tests.go @@ -731,7 +731,7 @@ func Test{{toTitle .ElementName}}LexicographicallyLargest(t *testing.T) { func Test{{toTitle .ElementName}}VecOps(t *testing.T) { assert := require.New(t) - const N = 1024*16 + 3 + const N = 1024*16 +4 a := make(Vector, N) b := make(Vector, N) c := make(Vector, N) diff --git a/field/goldilocks/element_test.go b/field/goldilocks/element_test.go index c8cf99312..f4d6606ff 100644 --- a/field/goldilocks/element_test.go +++ b/field/goldilocks/element_test.go @@ -655,7 +655,7 @@ func TestElementLexicographicallyLargest(t *testing.T) { func TestElementVecOps(t *testing.T) { assert := require.New(t) - const N = 1024*16 + 3 + const N = 1024*16 + 4 a := make(Vector, N) b := make(Vector, N) c := make(Vector, N) From dc15a6a6c091750fd16518e71584bb1b4f431b07 Mon Sep 17 00:00:00 2001 From: Gautam Botrel Date: Mon, 23 Sep 2024 18:46:05 +0000 Subject: [PATCH 10/14] style: cosmetics --- ecc/bls12-377/fr/element_ops_amd64.s | 60 +++++++++++++----------- ecc/bls12-381/fr/element_ops_amd64.s | 60 +++++++++++++----------- ecc/bls24-315/fr/element_ops_amd64.s | 60 +++++++++++++----------- ecc/bls24-317/fr/element_ops_amd64.s | 60 +++++++++++++----------- ecc/bn254/fp/element_ops_amd64.s | 60 +++++++++++++----------- ecc/bn254/fr/element_ops_amd64.s | 60 +++++++++++++----------- ecc/stark-curve/fp/element_ops_amd64.s | 60 +++++++++++++----------- ecc/stark-curve/fr/element_ops_amd64.s | 60 +++++++++++++----------- field/generator/asm/amd64/element_vec.go | 25 ++++++---- 9 files changed, 271 insertions(+), 234 deletions(-) diff --git a/ecc/bls12-377/fr/element_ops_amd64.s b/ecc/bls12-377/fr/element_ops_amd64.s index 49c4e31aa..54a104040 100644 --- a/ecc/bls12-377/fr/element_ops_amd64.s +++ b/ecc/bls12-377/fr/element_ops_amd64.s @@ -789,7 +789,7 @@ accumulate_11: MOVQ mu<>(SB), SI MOVQ R11, AX SHRQ $32, R12, AX - MULQ SI + MULQ SI // high bits of res stored in DX MULXQ q<>+0(SB), AX, SI SUBQ AX, BX SBBQ SI, DI @@ -803,33 +803,37 @@ accumulate_11: MULXQ q<>+24(SB), AX, SI SBBQ AX, R11 SBBQ SI, R12 - MOVQ res+0(FP), SI - MOVQ BX, 0(SI) - MOVQ DI, 8(SI) - MOVQ R9, 16(SI) - MOVQ R11, 24(SI) - - // TODO @gbotrel check if 2 conditional subtracts is guaranteed to be suffisant for mod reduce - SUBQ q<>+0(SB), BX - SBBQ q<>+8(SB), DI - SBBQ q<>+16(SB), R9 - SBBQ q<>+24(SB), R11 - SBBQ $0, R12 - JCS done_9 - MOVQ BX, 0(SI) - MOVQ DI, 8(SI) - MOVQ R9, 16(SI) - MOVQ R11, 24(SI) - SUBQ q<>+0(SB), BX - SBBQ q<>+8(SB), DI - SBBQ q<>+16(SB), R9 - SBBQ q<>+24(SB), R11 - SBBQ $0, R12 - JCS done_9 - MOVQ BX, 0(SI) - MOVQ DI, 8(SI) - MOVQ R9, 16(SI) - MOVQ R11, 24(SI) + MOVQ BX, R8 + MOVQ DI, R10 + MOVQ R9, R13 + MOVQ R11, CX + SUBQ q<>+0(SB), BX + SBBQ q<>+8(SB), DI + SBBQ q<>+16(SB), R9 + SBBQ q<>+24(SB), R11 + SBBQ $0, R12 + JCS modReduced_12 + MOVQ BX, R8 + MOVQ DI, R10 + MOVQ R9, R13 + MOVQ R11, CX + SUBQ q<>+0(SB), BX + SBBQ q<>+8(SB), DI + SBBQ q<>+16(SB), R9 + SBBQ q<>+24(SB), R11 + SBBQ $0, R12 + JCS modReduced_12 + MOVQ BX, R8 + MOVQ DI, R10 + MOVQ R9, R13 + MOVQ R11, CX + +modReduced_12: + MOVQ res+0(FP), SI + MOVQ R8, 0(SI) + MOVQ R10, 8(SI) + MOVQ R13, 16(SI) + MOVQ CX, 24(SI) done_9: RET diff --git a/ecc/bls12-381/fr/element_ops_amd64.s b/ecc/bls12-381/fr/element_ops_amd64.s index e58114b96..16bc87fbf 100644 --- a/ecc/bls12-381/fr/element_ops_amd64.s +++ b/ecc/bls12-381/fr/element_ops_amd64.s @@ 
-789,7 +789,7 @@ accumulate_11: MOVQ mu<>(SB), SI MOVQ R11, AX SHRQ $32, R12, AX - MULQ SI + MULQ SI // high bits of res stored in DX MULXQ q<>+0(SB), AX, SI SUBQ AX, BX SBBQ SI, DI @@ -803,33 +803,37 @@ accumulate_11: MULXQ q<>+24(SB), AX, SI SBBQ AX, R11 SBBQ SI, R12 - MOVQ res+0(FP), SI - MOVQ BX, 0(SI) - MOVQ DI, 8(SI) - MOVQ R9, 16(SI) - MOVQ R11, 24(SI) - - // TODO @gbotrel check if 2 conditional subtracts is guaranteed to be suffisant for mod reduce - SUBQ q<>+0(SB), BX - SBBQ q<>+8(SB), DI - SBBQ q<>+16(SB), R9 - SBBQ q<>+24(SB), R11 - SBBQ $0, R12 - JCS done_9 - MOVQ BX, 0(SI) - MOVQ DI, 8(SI) - MOVQ R9, 16(SI) - MOVQ R11, 24(SI) - SUBQ q<>+0(SB), BX - SBBQ q<>+8(SB), DI - SBBQ q<>+16(SB), R9 - SBBQ q<>+24(SB), R11 - SBBQ $0, R12 - JCS done_9 - MOVQ BX, 0(SI) - MOVQ DI, 8(SI) - MOVQ R9, 16(SI) - MOVQ R11, 24(SI) + MOVQ BX, R8 + MOVQ DI, R10 + MOVQ R9, R13 + MOVQ R11, CX + SUBQ q<>+0(SB), BX + SBBQ q<>+8(SB), DI + SBBQ q<>+16(SB), R9 + SBBQ q<>+24(SB), R11 + SBBQ $0, R12 + JCS modReduced_12 + MOVQ BX, R8 + MOVQ DI, R10 + MOVQ R9, R13 + MOVQ R11, CX + SUBQ q<>+0(SB), BX + SBBQ q<>+8(SB), DI + SBBQ q<>+16(SB), R9 + SBBQ q<>+24(SB), R11 + SBBQ $0, R12 + JCS modReduced_12 + MOVQ BX, R8 + MOVQ DI, R10 + MOVQ R9, R13 + MOVQ R11, CX + +modReduced_12: + MOVQ res+0(FP), SI + MOVQ R8, 0(SI) + MOVQ R10, 8(SI) + MOVQ R13, 16(SI) + MOVQ CX, 24(SI) done_9: RET diff --git a/ecc/bls24-315/fr/element_ops_amd64.s b/ecc/bls24-315/fr/element_ops_amd64.s index 0061bbefe..45756951f 100644 --- a/ecc/bls24-315/fr/element_ops_amd64.s +++ b/ecc/bls24-315/fr/element_ops_amd64.s @@ -789,7 +789,7 @@ accumulate_11: MOVQ mu<>(SB), SI MOVQ R11, AX SHRQ $32, R12, AX - MULQ SI + MULQ SI // high bits of res stored in DX MULXQ q<>+0(SB), AX, SI SUBQ AX, BX SBBQ SI, DI @@ -803,33 +803,37 @@ accumulate_11: MULXQ q<>+24(SB), AX, SI SBBQ AX, R11 SBBQ SI, R12 - MOVQ res+0(FP), SI - MOVQ BX, 0(SI) - MOVQ DI, 8(SI) - MOVQ R9, 16(SI) - MOVQ R11, 24(SI) - - // TODO @gbotrel check if 2 conditional subtracts is guaranteed to be suffisant for mod reduce - SUBQ q<>+0(SB), BX - SBBQ q<>+8(SB), DI - SBBQ q<>+16(SB), R9 - SBBQ q<>+24(SB), R11 - SBBQ $0, R12 - JCS done_9 - MOVQ BX, 0(SI) - MOVQ DI, 8(SI) - MOVQ R9, 16(SI) - MOVQ R11, 24(SI) - SUBQ q<>+0(SB), BX - SBBQ q<>+8(SB), DI - SBBQ q<>+16(SB), R9 - SBBQ q<>+24(SB), R11 - SBBQ $0, R12 - JCS done_9 - MOVQ BX, 0(SI) - MOVQ DI, 8(SI) - MOVQ R9, 16(SI) - MOVQ R11, 24(SI) + MOVQ BX, R8 + MOVQ DI, R10 + MOVQ R9, R13 + MOVQ R11, CX + SUBQ q<>+0(SB), BX + SBBQ q<>+8(SB), DI + SBBQ q<>+16(SB), R9 + SBBQ q<>+24(SB), R11 + SBBQ $0, R12 + JCS modReduced_12 + MOVQ BX, R8 + MOVQ DI, R10 + MOVQ R9, R13 + MOVQ R11, CX + SUBQ q<>+0(SB), BX + SBBQ q<>+8(SB), DI + SBBQ q<>+16(SB), R9 + SBBQ q<>+24(SB), R11 + SBBQ $0, R12 + JCS modReduced_12 + MOVQ BX, R8 + MOVQ DI, R10 + MOVQ R9, R13 + MOVQ R11, CX + +modReduced_12: + MOVQ res+0(FP), SI + MOVQ R8, 0(SI) + MOVQ R10, 8(SI) + MOVQ R13, 16(SI) + MOVQ CX, 24(SI) done_9: RET diff --git a/ecc/bls24-317/fr/element_ops_amd64.s b/ecc/bls24-317/fr/element_ops_amd64.s index 430db1b74..16c632592 100644 --- a/ecc/bls24-317/fr/element_ops_amd64.s +++ b/ecc/bls24-317/fr/element_ops_amd64.s @@ -789,7 +789,7 @@ accumulate_11: MOVQ mu<>(SB), SI MOVQ R11, AX SHRQ $32, R12, AX - MULQ SI + MULQ SI // high bits of res stored in DX MULXQ q<>+0(SB), AX, SI SUBQ AX, BX SBBQ SI, DI @@ -803,33 +803,37 @@ accumulate_11: MULXQ q<>+24(SB), AX, SI SBBQ AX, R11 SBBQ SI, R12 - MOVQ res+0(FP), SI - MOVQ BX, 0(SI) - MOVQ DI, 8(SI) - MOVQ R9, 16(SI) - MOVQ R11, 24(SI) - - // TODO 
@gbotrel check if 2 conditional subtracts is guaranteed to be suffisant for mod reduce - SUBQ q<>+0(SB), BX - SBBQ q<>+8(SB), DI - SBBQ q<>+16(SB), R9 - SBBQ q<>+24(SB), R11 - SBBQ $0, R12 - JCS done_9 - MOVQ BX, 0(SI) - MOVQ DI, 8(SI) - MOVQ R9, 16(SI) - MOVQ R11, 24(SI) - SUBQ q<>+0(SB), BX - SBBQ q<>+8(SB), DI - SBBQ q<>+16(SB), R9 - SBBQ q<>+24(SB), R11 - SBBQ $0, R12 - JCS done_9 - MOVQ BX, 0(SI) - MOVQ DI, 8(SI) - MOVQ R9, 16(SI) - MOVQ R11, 24(SI) + MOVQ BX, R8 + MOVQ DI, R10 + MOVQ R9, R13 + MOVQ R11, CX + SUBQ q<>+0(SB), BX + SBBQ q<>+8(SB), DI + SBBQ q<>+16(SB), R9 + SBBQ q<>+24(SB), R11 + SBBQ $0, R12 + JCS modReduced_12 + MOVQ BX, R8 + MOVQ DI, R10 + MOVQ R9, R13 + MOVQ R11, CX + SUBQ q<>+0(SB), BX + SBBQ q<>+8(SB), DI + SBBQ q<>+16(SB), R9 + SBBQ q<>+24(SB), R11 + SBBQ $0, R12 + JCS modReduced_12 + MOVQ BX, R8 + MOVQ DI, R10 + MOVQ R9, R13 + MOVQ R11, CX + +modReduced_12: + MOVQ res+0(FP), SI + MOVQ R8, 0(SI) + MOVQ R10, 8(SI) + MOVQ R13, 16(SI) + MOVQ CX, 24(SI) done_9: RET diff --git a/ecc/bn254/fp/element_ops_amd64.s b/ecc/bn254/fp/element_ops_amd64.s index a841afce0..44c900aad 100644 --- a/ecc/bn254/fp/element_ops_amd64.s +++ b/ecc/bn254/fp/element_ops_amd64.s @@ -789,7 +789,7 @@ accumulate_11: MOVQ mu<>(SB), SI MOVQ R11, AX SHRQ $32, R12, AX - MULQ SI + MULQ SI // high bits of res stored in DX MULXQ q<>+0(SB), AX, SI SUBQ AX, BX SBBQ SI, DI @@ -803,33 +803,37 @@ accumulate_11: MULXQ q<>+24(SB), AX, SI SBBQ AX, R11 SBBQ SI, R12 - MOVQ res+0(FP), SI - MOVQ BX, 0(SI) - MOVQ DI, 8(SI) - MOVQ R9, 16(SI) - MOVQ R11, 24(SI) - - // TODO @gbotrel check if 2 conditional subtracts is guaranteed to be suffisant for mod reduce - SUBQ q<>+0(SB), BX - SBBQ q<>+8(SB), DI - SBBQ q<>+16(SB), R9 - SBBQ q<>+24(SB), R11 - SBBQ $0, R12 - JCS done_9 - MOVQ BX, 0(SI) - MOVQ DI, 8(SI) - MOVQ R9, 16(SI) - MOVQ R11, 24(SI) - SUBQ q<>+0(SB), BX - SBBQ q<>+8(SB), DI - SBBQ q<>+16(SB), R9 - SBBQ q<>+24(SB), R11 - SBBQ $0, R12 - JCS done_9 - MOVQ BX, 0(SI) - MOVQ DI, 8(SI) - MOVQ R9, 16(SI) - MOVQ R11, 24(SI) + MOVQ BX, R8 + MOVQ DI, R10 + MOVQ R9, R13 + MOVQ R11, CX + SUBQ q<>+0(SB), BX + SBBQ q<>+8(SB), DI + SBBQ q<>+16(SB), R9 + SBBQ q<>+24(SB), R11 + SBBQ $0, R12 + JCS modReduced_12 + MOVQ BX, R8 + MOVQ DI, R10 + MOVQ R9, R13 + MOVQ R11, CX + SUBQ q<>+0(SB), BX + SBBQ q<>+8(SB), DI + SBBQ q<>+16(SB), R9 + SBBQ q<>+24(SB), R11 + SBBQ $0, R12 + JCS modReduced_12 + MOVQ BX, R8 + MOVQ DI, R10 + MOVQ R9, R13 + MOVQ R11, CX + +modReduced_12: + MOVQ res+0(FP), SI + MOVQ R8, 0(SI) + MOVQ R10, 8(SI) + MOVQ R13, 16(SI) + MOVQ CX, 24(SI) done_9: RET diff --git a/ecc/bn254/fr/element_ops_amd64.s b/ecc/bn254/fr/element_ops_amd64.s index 8e87e0dfa..811fa1397 100644 --- a/ecc/bn254/fr/element_ops_amd64.s +++ b/ecc/bn254/fr/element_ops_amd64.s @@ -789,7 +789,7 @@ accumulate_11: MOVQ mu<>(SB), SI MOVQ R11, AX SHRQ $32, R12, AX - MULQ SI + MULQ SI // high bits of res stored in DX MULXQ q<>+0(SB), AX, SI SUBQ AX, BX SBBQ SI, DI @@ -803,33 +803,37 @@ accumulate_11: MULXQ q<>+24(SB), AX, SI SBBQ AX, R11 SBBQ SI, R12 - MOVQ res+0(FP), SI - MOVQ BX, 0(SI) - MOVQ DI, 8(SI) - MOVQ R9, 16(SI) - MOVQ R11, 24(SI) - - // TODO @gbotrel check if 2 conditional subtracts is guaranteed to be suffisant for mod reduce - SUBQ q<>+0(SB), BX - SBBQ q<>+8(SB), DI - SBBQ q<>+16(SB), R9 - SBBQ q<>+24(SB), R11 - SBBQ $0, R12 - JCS done_9 - MOVQ BX, 0(SI) - MOVQ DI, 8(SI) - MOVQ R9, 16(SI) - MOVQ R11, 24(SI) - SUBQ q<>+0(SB), BX - SBBQ q<>+8(SB), DI - SBBQ q<>+16(SB), R9 - SBBQ q<>+24(SB), R11 - SBBQ $0, R12 - JCS done_9 - MOVQ BX, 0(SI) - 
MOVQ DI, 8(SI) - MOVQ R9, 16(SI) - MOVQ R11, 24(SI) + MOVQ BX, R8 + MOVQ DI, R10 + MOVQ R9, R13 + MOVQ R11, CX + SUBQ q<>+0(SB), BX + SBBQ q<>+8(SB), DI + SBBQ q<>+16(SB), R9 + SBBQ q<>+24(SB), R11 + SBBQ $0, R12 + JCS modReduced_12 + MOVQ BX, R8 + MOVQ DI, R10 + MOVQ R9, R13 + MOVQ R11, CX + SUBQ q<>+0(SB), BX + SBBQ q<>+8(SB), DI + SBBQ q<>+16(SB), R9 + SBBQ q<>+24(SB), R11 + SBBQ $0, R12 + JCS modReduced_12 + MOVQ BX, R8 + MOVQ DI, R10 + MOVQ R9, R13 + MOVQ R11, CX + +modReduced_12: + MOVQ res+0(FP), SI + MOVQ R8, 0(SI) + MOVQ R10, 8(SI) + MOVQ R13, 16(SI) + MOVQ CX, 24(SI) done_9: RET diff --git a/ecc/stark-curve/fp/element_ops_amd64.s b/ecc/stark-curve/fp/element_ops_amd64.s index fa703f31d..98bd7fa01 100644 --- a/ecc/stark-curve/fp/element_ops_amd64.s +++ b/ecc/stark-curve/fp/element_ops_amd64.s @@ -789,7 +789,7 @@ accumulate_11: MOVQ mu<>(SB), SI MOVQ R11, AX SHRQ $32, R12, AX - MULQ SI + MULQ SI // high bits of res stored in DX MULXQ q<>+0(SB), AX, SI SUBQ AX, BX SBBQ SI, DI @@ -803,33 +803,37 @@ accumulate_11: MULXQ q<>+24(SB), AX, SI SBBQ AX, R11 SBBQ SI, R12 - MOVQ res+0(FP), SI - MOVQ BX, 0(SI) - MOVQ DI, 8(SI) - MOVQ R9, 16(SI) - MOVQ R11, 24(SI) - - // TODO @gbotrel check if 2 conditional subtracts is guaranteed to be suffisant for mod reduce - SUBQ q<>+0(SB), BX - SBBQ q<>+8(SB), DI - SBBQ q<>+16(SB), R9 - SBBQ q<>+24(SB), R11 - SBBQ $0, R12 - JCS done_9 - MOVQ BX, 0(SI) - MOVQ DI, 8(SI) - MOVQ R9, 16(SI) - MOVQ R11, 24(SI) - SUBQ q<>+0(SB), BX - SBBQ q<>+8(SB), DI - SBBQ q<>+16(SB), R9 - SBBQ q<>+24(SB), R11 - SBBQ $0, R12 - JCS done_9 - MOVQ BX, 0(SI) - MOVQ DI, 8(SI) - MOVQ R9, 16(SI) - MOVQ R11, 24(SI) + MOVQ BX, R8 + MOVQ DI, R10 + MOVQ R9, R13 + MOVQ R11, CX + SUBQ q<>+0(SB), BX + SBBQ q<>+8(SB), DI + SBBQ q<>+16(SB), R9 + SBBQ q<>+24(SB), R11 + SBBQ $0, R12 + JCS modReduced_12 + MOVQ BX, R8 + MOVQ DI, R10 + MOVQ R9, R13 + MOVQ R11, CX + SUBQ q<>+0(SB), BX + SBBQ q<>+8(SB), DI + SBBQ q<>+16(SB), R9 + SBBQ q<>+24(SB), R11 + SBBQ $0, R12 + JCS modReduced_12 + MOVQ BX, R8 + MOVQ DI, R10 + MOVQ R9, R13 + MOVQ R11, CX + +modReduced_12: + MOVQ res+0(FP), SI + MOVQ R8, 0(SI) + MOVQ R10, 8(SI) + MOVQ R13, 16(SI) + MOVQ CX, 24(SI) done_9: RET diff --git a/ecc/stark-curve/fr/element_ops_amd64.s b/ecc/stark-curve/fr/element_ops_amd64.s index 1c21e37bd..e0149bd6c 100644 --- a/ecc/stark-curve/fr/element_ops_amd64.s +++ b/ecc/stark-curve/fr/element_ops_amd64.s @@ -789,7 +789,7 @@ accumulate_11: MOVQ mu<>(SB), SI MOVQ R11, AX SHRQ $32, R12, AX - MULQ SI + MULQ SI // high bits of res stored in DX MULXQ q<>+0(SB), AX, SI SUBQ AX, BX SBBQ SI, DI @@ -803,33 +803,37 @@ accumulate_11: MULXQ q<>+24(SB), AX, SI SBBQ AX, R11 SBBQ SI, R12 - MOVQ res+0(FP), SI - MOVQ BX, 0(SI) - MOVQ DI, 8(SI) - MOVQ R9, 16(SI) - MOVQ R11, 24(SI) - - // TODO @gbotrel check if 2 conditional subtracts is guaranteed to be suffisant for mod reduce - SUBQ q<>+0(SB), BX - SBBQ q<>+8(SB), DI - SBBQ q<>+16(SB), R9 - SBBQ q<>+24(SB), R11 - SBBQ $0, R12 - JCS done_9 - MOVQ BX, 0(SI) - MOVQ DI, 8(SI) - MOVQ R9, 16(SI) - MOVQ R11, 24(SI) - SUBQ q<>+0(SB), BX - SBBQ q<>+8(SB), DI - SBBQ q<>+16(SB), R9 - SBBQ q<>+24(SB), R11 - SBBQ $0, R12 - JCS done_9 - MOVQ BX, 0(SI) - MOVQ DI, 8(SI) - MOVQ R9, 16(SI) - MOVQ R11, 24(SI) + MOVQ BX, R8 + MOVQ DI, R10 + MOVQ R9, R13 + MOVQ R11, CX + SUBQ q<>+0(SB), BX + SBBQ q<>+8(SB), DI + SBBQ q<>+16(SB), R9 + SBBQ q<>+24(SB), R11 + SBBQ $0, R12 + JCS modReduced_12 + MOVQ BX, R8 + MOVQ DI, R10 + MOVQ R9, R13 + MOVQ R11, CX + SUBQ q<>+0(SB), BX + SBBQ q<>+8(SB), DI + SBBQ q<>+16(SB), R9 + 
SBBQ q<>+24(SB), R11 + SBBQ $0, R12 + JCS modReduced_12 + MOVQ BX, R8 + MOVQ DI, R10 + MOVQ R9, R13 + MOVQ R11, CX + +modReduced_12: + MOVQ res+0(FP), SI + MOVQ R8, 0(SI) + MOVQ R10, 8(SI) + MOVQ R13, 16(SI) + MOVQ CX, 24(SI) done_9: RET diff --git a/field/generator/asm/amd64/element_vec.go b/field/generator/asm/amd64/element_vec.go index 07d7078a9..a4ea06144 100644 --- a/field/generator/asm/amd64/element_vec.go +++ b/field/generator/asm/amd64/element_vec.go @@ -478,7 +478,7 @@ func (f *FFAmd64) generateSumVec() { f.MOVQ(f.mu(), mu) f.MOVQ(r3, amd64.AX) f.SHRQw("$32", r4, amd64.AX) - f.MULQ(mu) + f.MULQ(mu, "high bits of res stored in DX") f.MULXQ(f.qAt(0), amd64.AX, mu) f.SUBQ(amd64.AX, r0) @@ -497,21 +497,21 @@ func (f *FFAmd64) generateSumVec() { f.SBBQ(amd64.AX, r3) f.SBBQ(mu, r4) - addrRes := mu - f.MOVQ("res+0(FP)", addrRes) - f.Mov(r, addrRes) + // we need up to 2 conditional substractions to be < q + modReduced := f.NewLabel("modReduced") + t := f.PopN(®isters) + f.Mov(r[:4], t) // backup r0 to r3 (our result) // sub modulus - f.Comment("TODO @gbotrel check if 2 conditional subtracts is guaranteed to be suffisant for mod reduce") f.SUBQ(f.qAt(0), r0) f.SBBQ(f.qAt(1), r1) f.SBBQ(f.qAt(2), r2) f.SBBQ(f.qAt(3), r3) f.SBBQ("$0", r4) - // if borrow, we skip to the end - f.JCS(done) - f.Mov(r, addrRes) + // if borrow, we go to mod reduced + f.JCS(modReduced) + f.Mov(r, t) f.SUBQ(f.qAt(0), r0) f.SBBQ(f.qAt(1), r1) f.SBBQ(f.qAt(2), r2) @@ -519,8 +519,13 @@ func (f *FFAmd64) generateSumVec() { f.SBBQ("$0", r4) // if borrow, we skip to the end - f.JCS(done) - f.Mov(r, addrRes) + f.JCS(modReduced) + f.Mov(r, t) + + f.LABEL(modReduced) + addrRes := mu + f.MOVQ("res+0(FP)", addrRes) + f.Mov(t, addrRes) f.LABEL(done) From a66f5479760ae9aafb0ad2d3ae9da67bfce1c694 Mon Sep 17 00:00:00 2001 From: Gautam Botrel Date: Mon, 23 Sep 2024 18:52:43 +0000 Subject: [PATCH 11/14] test: better sum test --- ecc/bls12-377/fp/element_test.go | 15 +++++++++------ ecc/bls12-377/fr/element_test.go | 15 +++++++++------ ecc/bls12-381/fp/element_test.go | 15 +++++++++------ ecc/bls12-381/fr/element_test.go | 15 +++++++++------ ecc/bls24-315/fp/element_test.go | 15 +++++++++------ ecc/bls24-315/fr/element_test.go | 15 +++++++++------ ecc/bls24-317/fp/element_test.go | 15 +++++++++------ ecc/bls24-317/fr/element_test.go | 15 +++++++++------ ecc/bn254/fp/element_test.go | 15 +++++++++------ ecc/bn254/fr/element_test.go | 15 +++++++++------ ecc/bw6-633/fp/element_test.go | 15 +++++++++------ ecc/bw6-633/fr/element_test.go | 15 +++++++++------ ecc/bw6-761/fp/element_test.go | 15 +++++++++------ ecc/bw6-761/fr/element_test.go | 15 +++++++++------ ecc/secp256k1/fp/element_test.go | 15 +++++++++------ ecc/secp256k1/fr/element_test.go | 15 +++++++++------ ecc/stark-curve/fp/element_test.go | 15 +++++++++------ ecc/stark-curve/fr/element_test.go | 15 +++++++++------ .../generator/internal/templates/element/tests.go | 15 +++++++++------ field/goldilocks/element_test.go | 15 +++++++++------ 20 files changed, 180 insertions(+), 120 deletions(-) diff --git a/ecc/bls12-377/fp/element_test.go b/ecc/bls12-377/fp/element_test.go index 8376c41e2..cdb4a16ce 100644 --- a/ecc/bls12-377/fp/element_test.go +++ b/ecc/bls12-377/fp/element_test.go @@ -748,13 +748,16 @@ func TestElementVecOps(t *testing.T) { } // Vector sum - var sum Element - computed := c.Sum() - for i := 0; i < N; i++ { - sum.Add(&sum, &c[i]) - } + for i := 0; i < N/2; i++ { + subVec := c[:i] + var sum Element + computed := subVec.Sum() + for j := 0; j < len(subVec); j++ { + 
sum.Add(&sum, &subVec[j]) + } - assert.True(sum.Equal(&computed), "Vector sum failed") + assert.True(sum.Equal(&computed), "Vector sum failed") + } } func BenchmarkElementVecOps(b *testing.B) { diff --git a/ecc/bls12-377/fr/element_test.go b/ecc/bls12-377/fr/element_test.go index b576a09f4..1f61cd5fd 100644 --- a/ecc/bls12-377/fr/element_test.go +++ b/ecc/bls12-377/fr/element_test.go @@ -744,13 +744,16 @@ func TestElementVecOps(t *testing.T) { } // Vector sum - var sum Element - computed := c.Sum() - for i := 0; i < N; i++ { - sum.Add(&sum, &c[i]) - } + for i := 0; i < N/2; i++ { + subVec := c[:i] + var sum Element + computed := subVec.Sum() + for j := 0; j < len(subVec); j++ { + sum.Add(&sum, &subVec[j]) + } - assert.True(sum.Equal(&computed), "Vector sum failed") + assert.True(sum.Equal(&computed), "Vector sum failed") + } } func BenchmarkElementVecOps(b *testing.B) { diff --git a/ecc/bls12-381/fp/element_test.go b/ecc/bls12-381/fp/element_test.go index 88974f2cd..f24087938 100644 --- a/ecc/bls12-381/fp/element_test.go +++ b/ecc/bls12-381/fp/element_test.go @@ -748,13 +748,16 @@ func TestElementVecOps(t *testing.T) { } // Vector sum - var sum Element - computed := c.Sum() - for i := 0; i < N; i++ { - sum.Add(&sum, &c[i]) - } + for i := 0; i < N/2; i++ { + subVec := c[:i] + var sum Element + computed := subVec.Sum() + for j := 0; j < len(subVec); j++ { + sum.Add(&sum, &subVec[j]) + } - assert.True(sum.Equal(&computed), "Vector sum failed") + assert.True(sum.Equal(&computed), "Vector sum failed") + } } func BenchmarkElementVecOps(b *testing.B) { diff --git a/ecc/bls12-381/fr/element_test.go b/ecc/bls12-381/fr/element_test.go index ff0b3c188..e8962dce1 100644 --- a/ecc/bls12-381/fr/element_test.go +++ b/ecc/bls12-381/fr/element_test.go @@ -744,13 +744,16 @@ func TestElementVecOps(t *testing.T) { } // Vector sum - var sum Element - computed := c.Sum() - for i := 0; i < N; i++ { - sum.Add(&sum, &c[i]) - } + for i := 0; i < N/2; i++ { + subVec := c[:i] + var sum Element + computed := subVec.Sum() + for j := 0; j < len(subVec); j++ { + sum.Add(&sum, &subVec[j]) + } - assert.True(sum.Equal(&computed), "Vector sum failed") + assert.True(sum.Equal(&computed), "Vector sum failed") + } } func BenchmarkElementVecOps(b *testing.B) { diff --git a/ecc/bls24-315/fp/element_test.go b/ecc/bls24-315/fp/element_test.go index bdf5071ad..6c0db3afb 100644 --- a/ecc/bls24-315/fp/element_test.go +++ b/ecc/bls24-315/fp/element_test.go @@ -746,13 +746,16 @@ func TestElementVecOps(t *testing.T) { } // Vector sum - var sum Element - computed := c.Sum() - for i := 0; i < N; i++ { - sum.Add(&sum, &c[i]) - } + for i := 0; i < N/2; i++ { + subVec := c[:i] + var sum Element + computed := subVec.Sum() + for j := 0; j < len(subVec); j++ { + sum.Add(&sum, &subVec[j]) + } - assert.True(sum.Equal(&computed), "Vector sum failed") + assert.True(sum.Equal(&computed), "Vector sum failed") + } } func BenchmarkElementVecOps(b *testing.B) { diff --git a/ecc/bls24-315/fr/element_test.go b/ecc/bls24-315/fr/element_test.go index 007beea50..5324e7284 100644 --- a/ecc/bls24-315/fr/element_test.go +++ b/ecc/bls24-315/fr/element_test.go @@ -744,13 +744,16 @@ func TestElementVecOps(t *testing.T) { } // Vector sum - var sum Element - computed := c.Sum() - for i := 0; i < N; i++ { - sum.Add(&sum, &c[i]) - } + for i := 0; i < N/2; i++ { + subVec := c[:i] + var sum Element + computed := subVec.Sum() + for j := 0; j < len(subVec); j++ { + sum.Add(&sum, &subVec[j]) + } - assert.True(sum.Equal(&computed), "Vector sum failed") + 
assert.True(sum.Equal(&computed), "Vector sum failed") + } } func BenchmarkElementVecOps(b *testing.B) { diff --git a/ecc/bls24-317/fp/element_test.go b/ecc/bls24-317/fp/element_test.go index 9a20746bc..4659f2aca 100644 --- a/ecc/bls24-317/fp/element_test.go +++ b/ecc/bls24-317/fp/element_test.go @@ -746,13 +746,16 @@ func TestElementVecOps(t *testing.T) { } // Vector sum - var sum Element - computed := c.Sum() - for i := 0; i < N; i++ { - sum.Add(&sum, &c[i]) - } + for i := 0; i < N/2; i++ { + subVec := c[:i] + var sum Element + computed := subVec.Sum() + for j := 0; j < len(subVec); j++ { + sum.Add(&sum, &subVec[j]) + } - assert.True(sum.Equal(&computed), "Vector sum failed") + assert.True(sum.Equal(&computed), "Vector sum failed") + } } func BenchmarkElementVecOps(b *testing.B) { diff --git a/ecc/bls24-317/fr/element_test.go b/ecc/bls24-317/fr/element_test.go index 881f55ff3..cdff9acc0 100644 --- a/ecc/bls24-317/fr/element_test.go +++ b/ecc/bls24-317/fr/element_test.go @@ -744,13 +744,16 @@ func TestElementVecOps(t *testing.T) { } // Vector sum - var sum Element - computed := c.Sum() - for i := 0; i < N; i++ { - sum.Add(&sum, &c[i]) - } + for i := 0; i < N/2; i++ { + subVec := c[:i] + var sum Element + computed := subVec.Sum() + for j := 0; j < len(subVec); j++ { + sum.Add(&sum, &subVec[j]) + } - assert.True(sum.Equal(&computed), "Vector sum failed") + assert.True(sum.Equal(&computed), "Vector sum failed") + } } func BenchmarkElementVecOps(b *testing.B) { diff --git a/ecc/bn254/fp/element_test.go b/ecc/bn254/fp/element_test.go index 14847a533..7cfc6727d 100644 --- a/ecc/bn254/fp/element_test.go +++ b/ecc/bn254/fp/element_test.go @@ -744,13 +744,16 @@ func TestElementVecOps(t *testing.T) { } // Vector sum - var sum Element - computed := c.Sum() - for i := 0; i < N; i++ { - sum.Add(&sum, &c[i]) - } + for i := 0; i < N/2; i++ { + subVec := c[:i] + var sum Element + computed := subVec.Sum() + for j := 0; j < len(subVec); j++ { + sum.Add(&sum, &subVec[j]) + } - assert.True(sum.Equal(&computed), "Vector sum failed") + assert.True(sum.Equal(&computed), "Vector sum failed") + } } func BenchmarkElementVecOps(b *testing.B) { diff --git a/ecc/bn254/fr/element_test.go b/ecc/bn254/fr/element_test.go index 8c8e431dc..7fabe0443 100644 --- a/ecc/bn254/fr/element_test.go +++ b/ecc/bn254/fr/element_test.go @@ -744,13 +744,16 @@ func TestElementVecOps(t *testing.T) { } // Vector sum - var sum Element - computed := c.Sum() - for i := 0; i < N; i++ { - sum.Add(&sum, &c[i]) - } + for i := 0; i < N/2; i++ { + subVec := c[:i] + var sum Element + computed := subVec.Sum() + for j := 0; j < len(subVec); j++ { + sum.Add(&sum, &subVec[j]) + } - assert.True(sum.Equal(&computed), "Vector sum failed") + assert.True(sum.Equal(&computed), "Vector sum failed") + } } func BenchmarkElementVecOps(b *testing.B) { diff --git a/ecc/bw6-633/fp/element_test.go b/ecc/bw6-633/fp/element_test.go index 73b3146b6..08c89c058 100644 --- a/ecc/bw6-633/fp/element_test.go +++ b/ecc/bw6-633/fp/element_test.go @@ -756,13 +756,16 @@ func TestElementVecOps(t *testing.T) { } // Vector sum - var sum Element - computed := c.Sum() - for i := 0; i < N; i++ { - sum.Add(&sum, &c[i]) - } + for i := 0; i < N/2; i++ { + subVec := c[:i] + var sum Element + computed := subVec.Sum() + for j := 0; j < len(subVec); j++ { + sum.Add(&sum, &subVec[j]) + } - assert.True(sum.Equal(&computed), "Vector sum failed") + assert.True(sum.Equal(&computed), "Vector sum failed") + } } func BenchmarkElementVecOps(b *testing.B) { diff --git a/ecc/bw6-633/fr/element_test.go 
b/ecc/bw6-633/fr/element_test.go index a4b82539b..e74315101 100644 --- a/ecc/bw6-633/fr/element_test.go +++ b/ecc/bw6-633/fr/element_test.go @@ -746,13 +746,16 @@ func TestElementVecOps(t *testing.T) { } // Vector sum - var sum Element - computed := c.Sum() - for i := 0; i < N; i++ { - sum.Add(&sum, &c[i]) - } + for i := 0; i < N/2; i++ { + subVec := c[:i] + var sum Element + computed := subVec.Sum() + for j := 0; j < len(subVec); j++ { + sum.Add(&sum, &subVec[j]) + } - assert.True(sum.Equal(&computed), "Vector sum failed") + assert.True(sum.Equal(&computed), "Vector sum failed") + } } func BenchmarkElementVecOps(b *testing.B) { diff --git a/ecc/bw6-761/fp/element_test.go b/ecc/bw6-761/fp/element_test.go index d1f404566..2a8f265cc 100644 --- a/ecc/bw6-761/fp/element_test.go +++ b/ecc/bw6-761/fp/element_test.go @@ -760,13 +760,16 @@ func TestElementVecOps(t *testing.T) { } // Vector sum - var sum Element - computed := c.Sum() - for i := 0; i < N; i++ { - sum.Add(&sum, &c[i]) - } + for i := 0; i < N/2; i++ { + subVec := c[:i] + var sum Element + computed := subVec.Sum() + for j := 0; j < len(subVec); j++ { + sum.Add(&sum, &subVec[j]) + } - assert.True(sum.Equal(&computed), "Vector sum failed") + assert.True(sum.Equal(&computed), "Vector sum failed") + } } func BenchmarkElementVecOps(b *testing.B) { diff --git a/ecc/bw6-761/fr/element_test.go b/ecc/bw6-761/fr/element_test.go index 90d162aee..30696a100 100644 --- a/ecc/bw6-761/fr/element_test.go +++ b/ecc/bw6-761/fr/element_test.go @@ -748,13 +748,16 @@ func TestElementVecOps(t *testing.T) { } // Vector sum - var sum Element - computed := c.Sum() - for i := 0; i < N; i++ { - sum.Add(&sum, &c[i]) - } + for i := 0; i < N/2; i++ { + subVec := c[:i] + var sum Element + computed := subVec.Sum() + for j := 0; j < len(subVec); j++ { + sum.Add(&sum, &subVec[j]) + } - assert.True(sum.Equal(&computed), "Vector sum failed") + assert.True(sum.Equal(&computed), "Vector sum failed") + } } func BenchmarkElementVecOps(b *testing.B) { diff --git a/ecc/secp256k1/fp/element_test.go b/ecc/secp256k1/fp/element_test.go index 4cedc9a12..831cd3efe 100644 --- a/ecc/secp256k1/fp/element_test.go +++ b/ecc/secp256k1/fp/element_test.go @@ -742,13 +742,16 @@ func TestElementVecOps(t *testing.T) { } // Vector sum - var sum Element - computed := c.Sum() - for i := 0; i < N; i++ { - sum.Add(&sum, &c[i]) - } + for i := 0; i < N/2; i++ { + subVec := c[:i] + var sum Element + computed := subVec.Sum() + for j := 0; j < len(subVec); j++ { + sum.Add(&sum, &subVec[j]) + } - assert.True(sum.Equal(&computed), "Vector sum failed") + assert.True(sum.Equal(&computed), "Vector sum failed") + } } func BenchmarkElementVecOps(b *testing.B) { diff --git a/ecc/secp256k1/fr/element_test.go b/ecc/secp256k1/fr/element_test.go index f97f76998..d47e068b0 100644 --- a/ecc/secp256k1/fr/element_test.go +++ b/ecc/secp256k1/fr/element_test.go @@ -742,13 +742,16 @@ func TestElementVecOps(t *testing.T) { } // Vector sum - var sum Element - computed := c.Sum() - for i := 0; i < N; i++ { - sum.Add(&sum, &c[i]) - } + for i := 0; i < N/2; i++ { + subVec := c[:i] + var sum Element + computed := subVec.Sum() + for j := 0; j < len(subVec); j++ { + sum.Add(&sum, &subVec[j]) + } - assert.True(sum.Equal(&computed), "Vector sum failed") + assert.True(sum.Equal(&computed), "Vector sum failed") + } } func BenchmarkElementVecOps(b *testing.B) { diff --git a/ecc/stark-curve/fp/element_test.go b/ecc/stark-curve/fp/element_test.go index c51d1708a..8efeea221 100644 --- a/ecc/stark-curve/fp/element_test.go +++ 
b/ecc/stark-curve/fp/element_test.go @@ -744,13 +744,16 @@ func TestElementVecOps(t *testing.T) { } // Vector sum - var sum Element - computed := c.Sum() - for i := 0; i < N; i++ { - sum.Add(&sum, &c[i]) - } + for i := 0; i < N/2; i++ { + subVec := c[:i] + var sum Element + computed := subVec.Sum() + for j := 0; j < len(subVec); j++ { + sum.Add(&sum, &subVec[j]) + } - assert.True(sum.Equal(&computed), "Vector sum failed") + assert.True(sum.Equal(&computed), "Vector sum failed") + } } func BenchmarkElementVecOps(b *testing.B) { diff --git a/ecc/stark-curve/fr/element_test.go b/ecc/stark-curve/fr/element_test.go index a3478b0db..ceae12aac 100644 --- a/ecc/stark-curve/fr/element_test.go +++ b/ecc/stark-curve/fr/element_test.go @@ -744,13 +744,16 @@ func TestElementVecOps(t *testing.T) { } // Vector sum - var sum Element - computed := c.Sum() - for i := 0; i < N; i++ { - sum.Add(&sum, &c[i]) - } + for i := 0; i < N/2; i++ { + subVec := c[:i] + var sum Element + computed := subVec.Sum() + for j := 0; j < len(subVec); j++ { + sum.Add(&sum, &subVec[j]) + } - assert.True(sum.Equal(&computed), "Vector sum failed") + assert.True(sum.Equal(&computed), "Vector sum failed") + } } func BenchmarkElementVecOps(b *testing.B) { diff --git a/field/generator/internal/templates/element/tests.go b/field/generator/internal/templates/element/tests.go index f07e4cd0b..603b4a8f7 100644 --- a/field/generator/internal/templates/element/tests.go +++ b/field/generator/internal/templates/element/tests.go @@ -765,13 +765,16 @@ func Test{{toTitle .ElementName}}VecOps(t *testing.T) { } // Vector sum - var sum {{.ElementName}} - computed := c.Sum() - for i := 0; i < N; i++ { - sum.Add(&sum, &c[i]) - } + for i := 0; i < N/2; i++ { + subVec := c[:i] + var sum {{.ElementName}} + computed := subVec.Sum() + for j := 0; j < len(subVec); j++ { + sum.Add(&sum, &subVec[j]) + } - assert.True(sum.Equal(&computed), "Vector sum failed") + assert.True(sum.Equal(&computed), "Vector sum failed") + } } func Benchmark{{toTitle .ElementName}}VecOps(b *testing.B) { diff --git a/field/goldilocks/element_test.go b/field/goldilocks/element_test.go index f4d6606ff..20b183ded 100644 --- a/field/goldilocks/element_test.go +++ b/field/goldilocks/element_test.go @@ -689,13 +689,16 @@ func TestElementVecOps(t *testing.T) { } // Vector sum - var sum Element - computed := c.Sum() - for i := 0; i < N; i++ { - sum.Add(&sum, &c[i]) - } + for i := 0; i < N/2; i++ { + subVec := c[:i] + var sum Element + computed := subVec.Sum() + for j := 0; j < len(subVec); j++ { + sum.Add(&sum, &subVec[j]) + } - assert.True(sum.Equal(&computed), "Vector sum failed") + assert.True(sum.Equal(&computed), "Vector sum failed") + } } func BenchmarkElementVecOps(b *testing.B) { From 5b2b11d23aabd29c75e6ab34530b50901975a9f0 Mon Sep 17 00:00:00 2001 From: Gautam Botrel Date: Mon, 23 Sep 2024 19:23:00 +0000 Subject: [PATCH 12/14] test: more test --- ecc/bls12-377/fp/element_test.go | 40 +++++++++++++++++++ ecc/bls12-377/fr/element_test.go | 40 +++++++++++++++++++ ecc/bls12-381/fp/element_test.go | 40 +++++++++++++++++++ ecc/bls12-381/fr/element_test.go | 40 +++++++++++++++++++ ecc/bls24-315/fp/element_test.go | 40 +++++++++++++++++++ ecc/bls24-315/fr/element_test.go | 40 +++++++++++++++++++ ecc/bls24-317/fp/element_test.go | 40 +++++++++++++++++++ ecc/bls24-317/fr/element_test.go | 40 +++++++++++++++++++ ecc/bn254/fp/element_test.go | 40 +++++++++++++++++++ ecc/bn254/fr/element_test.go | 40 +++++++++++++++++++ ecc/bw6-633/fp/element_test.go | 40 +++++++++++++++++++ 
ecc/bw6-633/fr/element_test.go | 40 +++++++++++++++++++ ecc/bw6-761/fp/element_test.go | 40 +++++++++++++++++++ ecc/bw6-761/fr/element_test.go | 40 +++++++++++++++++++ ecc/secp256k1/fp/element_test.go | 40 +++++++++++++++++++ ecc/secp256k1/fr/element_test.go | 40 +++++++++++++++++++ ecc/stark-curve/fp/element_test.go | 40 +++++++++++++++++++ ecc/stark-curve/fr/element_test.go | 40 +++++++++++++++++++ .../internal/templates/element/tests.go | 40 +++++++++++++++++++ field/goldilocks/element_test.go | 40 +++++++++++++++++++ 20 files changed, 800 insertions(+) diff --git a/ecc/bls12-377/fp/element_test.go b/ecc/bls12-377/fp/element_test.go index cdb4a16ce..a1be8c091 100644 --- a/ecc/bls12-377/fp/element_test.go +++ b/ecc/bls12-377/fp/element_test.go @@ -718,9 +718,22 @@ func TestElementVecOps(t *testing.T) { a := make(Vector, N) b := make(Vector, N) c := make(Vector, N) + m := make(Vector, N) + + // set m to max values element + // it's not really q-1 (since we have montgomery representation) + // but it's the "largest" legal value + qMinus1 := new(big.Int).Sub(Modulus(), big.NewInt(1)) + + var eQMinus1 Element + for i, v := range qMinus1.Bits() { + eQMinus1[i] = uint64(v) + } + for i := 0; i < N; i++ { a[i].SetRandom() b[i].SetRandom() + m[i] = eQMinus1 } // Vector addition @@ -730,6 +743,12 @@ func TestElementVecOps(t *testing.T) { expected.Add(&a[i], &b[i]) assert.True(c[i].Equal(&expected), "Vector addition failed") } + c.Add(a, m) + for i := 0; i < N; i++ { + var expected Element + expected.Add(&a[i], &m[i]) + assert.True(c[i].Equal(&expected), "Vector addition failed") + } // Vector subtraction c.Sub(a, b) @@ -738,6 +757,12 @@ func TestElementVecOps(t *testing.T) { expected.Sub(&a[i], &b[i]) assert.True(c[i].Equal(&expected), "Vector subtraction failed") } + c.Sub(a, m) + for i := 0; i < N; i++ { + var expected Element + expected.Sub(&a[i], &m[i]) + assert.True(c[i].Equal(&expected), "Vector subtraction failed") + } // Vector scaling c.ScalarMul(a, &b[0]) @@ -746,6 +771,12 @@ func TestElementVecOps(t *testing.T) { expected.Mul(&a[i], &b[0]) assert.True(c[i].Equal(&expected), "Vector scaling failed") } + c.ScalarMul(m, &b[0]) + for i := 0; i < N; i++ { + var expected Element + expected.Mul(&m[i], &b[0]) + assert.True(c[i].Equal(&expected), "Vector scaling failed") + } // Vector sum for i := 0; i < N/2; i++ { @@ -757,6 +788,15 @@ func TestElementVecOps(t *testing.T) { } assert.True(sum.Equal(&computed), "Vector sum failed") + + subVec = m[:i] + computed = subVec.Sum() + sum.SetZero() + for j := 0; j < len(subVec); j++ { + sum.Add(&sum, &subVec[j]) + } + + assert.True(sum.Equal(&computed), "Vector sum failed") } } diff --git a/ecc/bls12-377/fr/element_test.go b/ecc/bls12-377/fr/element_test.go index 1f61cd5fd..4e3d89236 100644 --- a/ecc/bls12-377/fr/element_test.go +++ b/ecc/bls12-377/fr/element_test.go @@ -714,9 +714,22 @@ func TestElementVecOps(t *testing.T) { a := make(Vector, N) b := make(Vector, N) c := make(Vector, N) + m := make(Vector, N) + + // set m to max values element + // it's not really q-1 (since we have montgomery representation) + // but it's the "largest" legal value + qMinus1 := new(big.Int).Sub(Modulus(), big.NewInt(1)) + + var eQMinus1 Element + for i, v := range qMinus1.Bits() { + eQMinus1[i] = uint64(v) + } + for i := 0; i < N; i++ { a[i].SetRandom() b[i].SetRandom() + m[i] = eQMinus1 } // Vector addition @@ -726,6 +739,12 @@ func TestElementVecOps(t *testing.T) { expected.Add(&a[i], &b[i]) assert.True(c[i].Equal(&expected), "Vector addition failed") } + c.Add(a, m) 
+ for i := 0; i < N; i++ { + var expected Element + expected.Add(&a[i], &m[i]) + assert.True(c[i].Equal(&expected), "Vector addition failed") + } // Vector subtraction c.Sub(a, b) @@ -734,6 +753,12 @@ func TestElementVecOps(t *testing.T) { expected.Sub(&a[i], &b[i]) assert.True(c[i].Equal(&expected), "Vector subtraction failed") } + c.Sub(a, m) + for i := 0; i < N; i++ { + var expected Element + expected.Sub(&a[i], &m[i]) + assert.True(c[i].Equal(&expected), "Vector subtraction failed") + } // Vector scaling c.ScalarMul(a, &b[0]) @@ -742,6 +767,12 @@ func TestElementVecOps(t *testing.T) { expected.Mul(&a[i], &b[0]) assert.True(c[i].Equal(&expected), "Vector scaling failed") } + c.ScalarMul(m, &b[0]) + for i := 0; i < N; i++ { + var expected Element + expected.Mul(&m[i], &b[0]) + assert.True(c[i].Equal(&expected), "Vector scaling failed") + } // Vector sum for i := 0; i < N/2; i++ { @@ -753,6 +784,15 @@ func TestElementVecOps(t *testing.T) { } assert.True(sum.Equal(&computed), "Vector sum failed") + + subVec = m[:i] + computed = subVec.Sum() + sum.SetZero() + for j := 0; j < len(subVec); j++ { + sum.Add(&sum, &subVec[j]) + } + + assert.True(sum.Equal(&computed), "Vector sum failed") } } diff --git a/ecc/bls12-381/fp/element_test.go b/ecc/bls12-381/fp/element_test.go index f24087938..cb26b5425 100644 --- a/ecc/bls12-381/fp/element_test.go +++ b/ecc/bls12-381/fp/element_test.go @@ -718,9 +718,22 @@ func TestElementVecOps(t *testing.T) { a := make(Vector, N) b := make(Vector, N) c := make(Vector, N) + m := make(Vector, N) + + // set m to max values element + // it's not really q-1 (since we have montgomery representation) + // but it's the "largest" legal value + qMinus1 := new(big.Int).Sub(Modulus(), big.NewInt(1)) + + var eQMinus1 Element + for i, v := range qMinus1.Bits() { + eQMinus1[i] = uint64(v) + } + for i := 0; i < N; i++ { a[i].SetRandom() b[i].SetRandom() + m[i] = eQMinus1 } // Vector addition @@ -730,6 +743,12 @@ func TestElementVecOps(t *testing.T) { expected.Add(&a[i], &b[i]) assert.True(c[i].Equal(&expected), "Vector addition failed") } + c.Add(a, m) + for i := 0; i < N; i++ { + var expected Element + expected.Add(&a[i], &m[i]) + assert.True(c[i].Equal(&expected), "Vector addition failed") + } // Vector subtraction c.Sub(a, b) @@ -738,6 +757,12 @@ func TestElementVecOps(t *testing.T) { expected.Sub(&a[i], &b[i]) assert.True(c[i].Equal(&expected), "Vector subtraction failed") } + c.Sub(a, m) + for i := 0; i < N; i++ { + var expected Element + expected.Sub(&a[i], &m[i]) + assert.True(c[i].Equal(&expected), "Vector subtraction failed") + } // Vector scaling c.ScalarMul(a, &b[0]) @@ -746,6 +771,12 @@ func TestElementVecOps(t *testing.T) { expected.Mul(&a[i], &b[0]) assert.True(c[i].Equal(&expected), "Vector scaling failed") } + c.ScalarMul(m, &b[0]) + for i := 0; i < N; i++ { + var expected Element + expected.Mul(&m[i], &b[0]) + assert.True(c[i].Equal(&expected), "Vector scaling failed") + } // Vector sum for i := 0; i < N/2; i++ { @@ -757,6 +788,15 @@ func TestElementVecOps(t *testing.T) { } assert.True(sum.Equal(&computed), "Vector sum failed") + + subVec = m[:i] + computed = subVec.Sum() + sum.SetZero() + for j := 0; j < len(subVec); j++ { + sum.Add(&sum, &subVec[j]) + } + + assert.True(sum.Equal(&computed), "Vector sum failed") } } diff --git a/ecc/bls12-381/fr/element_test.go b/ecc/bls12-381/fr/element_test.go index e8962dce1..68d6739f5 100644 --- a/ecc/bls12-381/fr/element_test.go +++ b/ecc/bls12-381/fr/element_test.go @@ -714,9 +714,22 @@ func TestElementVecOps(t *testing.T) 
{ a := make(Vector, N) b := make(Vector, N) c := make(Vector, N) + m := make(Vector, N) + + // set m to max values element + // it's not really q-1 (since we have montgomery representation) + // but it's the "largest" legal value + qMinus1 := new(big.Int).Sub(Modulus(), big.NewInt(1)) + + var eQMinus1 Element + for i, v := range qMinus1.Bits() { + eQMinus1[i] = uint64(v) + } + for i := 0; i < N; i++ { a[i].SetRandom() b[i].SetRandom() + m[i] = eQMinus1 } // Vector addition @@ -726,6 +739,12 @@ func TestElementVecOps(t *testing.T) { expected.Add(&a[i], &b[i]) assert.True(c[i].Equal(&expected), "Vector addition failed") } + c.Add(a, m) + for i := 0; i < N; i++ { + var expected Element + expected.Add(&a[i], &m[i]) + assert.True(c[i].Equal(&expected), "Vector addition failed") + } // Vector subtraction c.Sub(a, b) @@ -734,6 +753,12 @@ func TestElementVecOps(t *testing.T) { expected.Sub(&a[i], &b[i]) assert.True(c[i].Equal(&expected), "Vector subtraction failed") } + c.Sub(a, m) + for i := 0; i < N; i++ { + var expected Element + expected.Sub(&a[i], &m[i]) + assert.True(c[i].Equal(&expected), "Vector subtraction failed") + } // Vector scaling c.ScalarMul(a, &b[0]) @@ -742,6 +767,12 @@ func TestElementVecOps(t *testing.T) { expected.Mul(&a[i], &b[0]) assert.True(c[i].Equal(&expected), "Vector scaling failed") } + c.ScalarMul(m, &b[0]) + for i := 0; i < N; i++ { + var expected Element + expected.Mul(&m[i], &b[0]) + assert.True(c[i].Equal(&expected), "Vector scaling failed") + } // Vector sum for i := 0; i < N/2; i++ { @@ -753,6 +784,15 @@ func TestElementVecOps(t *testing.T) { } assert.True(sum.Equal(&computed), "Vector sum failed") + + subVec = m[:i] + computed = subVec.Sum() + sum.SetZero() + for j := 0; j < len(subVec); j++ { + sum.Add(&sum, &subVec[j]) + } + + assert.True(sum.Equal(&computed), "Vector sum failed") } } diff --git a/ecc/bls24-315/fp/element_test.go b/ecc/bls24-315/fp/element_test.go index 6c0db3afb..c48910446 100644 --- a/ecc/bls24-315/fp/element_test.go +++ b/ecc/bls24-315/fp/element_test.go @@ -716,9 +716,22 @@ func TestElementVecOps(t *testing.T) { a := make(Vector, N) b := make(Vector, N) c := make(Vector, N) + m := make(Vector, N) + + // set m to max values element + // it's not really q-1 (since we have montgomery representation) + // but it's the "largest" legal value + qMinus1 := new(big.Int).Sub(Modulus(), big.NewInt(1)) + + var eQMinus1 Element + for i, v := range qMinus1.Bits() { + eQMinus1[i] = uint64(v) + } + for i := 0; i < N; i++ { a[i].SetRandom() b[i].SetRandom() + m[i] = eQMinus1 } // Vector addition @@ -728,6 +741,12 @@ func TestElementVecOps(t *testing.T) { expected.Add(&a[i], &b[i]) assert.True(c[i].Equal(&expected), "Vector addition failed") } + c.Add(a, m) + for i := 0; i < N; i++ { + var expected Element + expected.Add(&a[i], &m[i]) + assert.True(c[i].Equal(&expected), "Vector addition failed") + } // Vector subtraction c.Sub(a, b) @@ -736,6 +755,12 @@ func TestElementVecOps(t *testing.T) { expected.Sub(&a[i], &b[i]) assert.True(c[i].Equal(&expected), "Vector subtraction failed") } + c.Sub(a, m) + for i := 0; i < N; i++ { + var expected Element + expected.Sub(&a[i], &m[i]) + assert.True(c[i].Equal(&expected), "Vector subtraction failed") + } // Vector scaling c.ScalarMul(a, &b[0]) @@ -744,6 +769,12 @@ func TestElementVecOps(t *testing.T) { expected.Mul(&a[i], &b[0]) assert.True(c[i].Equal(&expected), "Vector scaling failed") } + c.ScalarMul(m, &b[0]) + for i := 0; i < N; i++ { + var expected Element + expected.Mul(&m[i], &b[0]) + 
assert.True(c[i].Equal(&expected), "Vector scaling failed") + } // Vector sum for i := 0; i < N/2; i++ { @@ -755,6 +786,15 @@ func TestElementVecOps(t *testing.T) { } assert.True(sum.Equal(&computed), "Vector sum failed") + + subVec = m[:i] + computed = subVec.Sum() + sum.SetZero() + for j := 0; j < len(subVec); j++ { + sum.Add(&sum, &subVec[j]) + } + + assert.True(sum.Equal(&computed), "Vector sum failed") } } diff --git a/ecc/bls24-315/fr/element_test.go b/ecc/bls24-315/fr/element_test.go index 5324e7284..e382132bd 100644 --- a/ecc/bls24-315/fr/element_test.go +++ b/ecc/bls24-315/fr/element_test.go @@ -714,9 +714,22 @@ func TestElementVecOps(t *testing.T) { a := make(Vector, N) b := make(Vector, N) c := make(Vector, N) + m := make(Vector, N) + + // set m to max values element + // it's not really q-1 (since we have montgomery representation) + // but it's the "largest" legal value + qMinus1 := new(big.Int).Sub(Modulus(), big.NewInt(1)) + + var eQMinus1 Element + for i, v := range qMinus1.Bits() { + eQMinus1[i] = uint64(v) + } + for i := 0; i < N; i++ { a[i].SetRandom() b[i].SetRandom() + m[i] = eQMinus1 } // Vector addition @@ -726,6 +739,12 @@ func TestElementVecOps(t *testing.T) { expected.Add(&a[i], &b[i]) assert.True(c[i].Equal(&expected), "Vector addition failed") } + c.Add(a, m) + for i := 0; i < N; i++ { + var expected Element + expected.Add(&a[i], &m[i]) + assert.True(c[i].Equal(&expected), "Vector addition failed") + } // Vector subtraction c.Sub(a, b) @@ -734,6 +753,12 @@ func TestElementVecOps(t *testing.T) { expected.Sub(&a[i], &b[i]) assert.True(c[i].Equal(&expected), "Vector subtraction failed") } + c.Sub(a, m) + for i := 0; i < N; i++ { + var expected Element + expected.Sub(&a[i], &m[i]) + assert.True(c[i].Equal(&expected), "Vector subtraction failed") + } // Vector scaling c.ScalarMul(a, &b[0]) @@ -742,6 +767,12 @@ func TestElementVecOps(t *testing.T) { expected.Mul(&a[i], &b[0]) assert.True(c[i].Equal(&expected), "Vector scaling failed") } + c.ScalarMul(m, &b[0]) + for i := 0; i < N; i++ { + var expected Element + expected.Mul(&m[i], &b[0]) + assert.True(c[i].Equal(&expected), "Vector scaling failed") + } // Vector sum for i := 0; i < N/2; i++ { @@ -753,6 +784,15 @@ func TestElementVecOps(t *testing.T) { } assert.True(sum.Equal(&computed), "Vector sum failed") + + subVec = m[:i] + computed = subVec.Sum() + sum.SetZero() + for j := 0; j < len(subVec); j++ { + sum.Add(&sum, &subVec[j]) + } + + assert.True(sum.Equal(&computed), "Vector sum failed") } } diff --git a/ecc/bls24-317/fp/element_test.go b/ecc/bls24-317/fp/element_test.go index 4659f2aca..da5ab0e0b 100644 --- a/ecc/bls24-317/fp/element_test.go +++ b/ecc/bls24-317/fp/element_test.go @@ -716,9 +716,22 @@ func TestElementVecOps(t *testing.T) { a := make(Vector, N) b := make(Vector, N) c := make(Vector, N) + m := make(Vector, N) + + // set m to max values element + // it's not really q-1 (since we have montgomery representation) + // but it's the "largest" legal value + qMinus1 := new(big.Int).Sub(Modulus(), big.NewInt(1)) + + var eQMinus1 Element + for i, v := range qMinus1.Bits() { + eQMinus1[i] = uint64(v) + } + for i := 0; i < N; i++ { a[i].SetRandom() b[i].SetRandom() + m[i] = eQMinus1 } // Vector addition @@ -728,6 +741,12 @@ func TestElementVecOps(t *testing.T) { expected.Add(&a[i], &b[i]) assert.True(c[i].Equal(&expected), "Vector addition failed") } + c.Add(a, m) + for i := 0; i < N; i++ { + var expected Element + expected.Add(&a[i], &m[i]) + assert.True(c[i].Equal(&expected), "Vector addition failed") + } // 
Vector subtraction c.Sub(a, b) @@ -736,6 +755,12 @@ func TestElementVecOps(t *testing.T) { expected.Sub(&a[i], &b[i]) assert.True(c[i].Equal(&expected), "Vector subtraction failed") } + c.Sub(a, m) + for i := 0; i < N; i++ { + var expected Element + expected.Sub(&a[i], &m[i]) + assert.True(c[i].Equal(&expected), "Vector subtraction failed") + } // Vector scaling c.ScalarMul(a, &b[0]) @@ -744,6 +769,12 @@ func TestElementVecOps(t *testing.T) { expected.Mul(&a[i], &b[0]) assert.True(c[i].Equal(&expected), "Vector scaling failed") } + c.ScalarMul(m, &b[0]) + for i := 0; i < N; i++ { + var expected Element + expected.Mul(&m[i], &b[0]) + assert.True(c[i].Equal(&expected), "Vector scaling failed") + } // Vector sum for i := 0; i < N/2; i++ { @@ -755,6 +786,15 @@ func TestElementVecOps(t *testing.T) { } assert.True(sum.Equal(&computed), "Vector sum failed") + + subVec = m[:i] + computed = subVec.Sum() + sum.SetZero() + for j := 0; j < len(subVec); j++ { + sum.Add(&sum, &subVec[j]) + } + + assert.True(sum.Equal(&computed), "Vector sum failed") } } diff --git a/ecc/bls24-317/fr/element_test.go b/ecc/bls24-317/fr/element_test.go index cdff9acc0..84f8af2df 100644 --- a/ecc/bls24-317/fr/element_test.go +++ b/ecc/bls24-317/fr/element_test.go @@ -714,9 +714,22 @@ func TestElementVecOps(t *testing.T) { a := make(Vector, N) b := make(Vector, N) c := make(Vector, N) + m := make(Vector, N) + + // set m to max values element + // it's not really q-1 (since we have montgomery representation) + // but it's the "largest" legal value + qMinus1 := new(big.Int).Sub(Modulus(), big.NewInt(1)) + + var eQMinus1 Element + for i, v := range qMinus1.Bits() { + eQMinus1[i] = uint64(v) + } + for i := 0; i < N; i++ { a[i].SetRandom() b[i].SetRandom() + m[i] = eQMinus1 } // Vector addition @@ -726,6 +739,12 @@ func TestElementVecOps(t *testing.T) { expected.Add(&a[i], &b[i]) assert.True(c[i].Equal(&expected), "Vector addition failed") } + c.Add(a, m) + for i := 0; i < N; i++ { + var expected Element + expected.Add(&a[i], &m[i]) + assert.True(c[i].Equal(&expected), "Vector addition failed") + } // Vector subtraction c.Sub(a, b) @@ -734,6 +753,12 @@ func TestElementVecOps(t *testing.T) { expected.Sub(&a[i], &b[i]) assert.True(c[i].Equal(&expected), "Vector subtraction failed") } + c.Sub(a, m) + for i := 0; i < N; i++ { + var expected Element + expected.Sub(&a[i], &m[i]) + assert.True(c[i].Equal(&expected), "Vector subtraction failed") + } // Vector scaling c.ScalarMul(a, &b[0]) @@ -742,6 +767,12 @@ func TestElementVecOps(t *testing.T) { expected.Mul(&a[i], &b[0]) assert.True(c[i].Equal(&expected), "Vector scaling failed") } + c.ScalarMul(m, &b[0]) + for i := 0; i < N; i++ { + var expected Element + expected.Mul(&m[i], &b[0]) + assert.True(c[i].Equal(&expected), "Vector scaling failed") + } // Vector sum for i := 0; i < N/2; i++ { @@ -753,6 +784,15 @@ func TestElementVecOps(t *testing.T) { } assert.True(sum.Equal(&computed), "Vector sum failed") + + subVec = m[:i] + computed = subVec.Sum() + sum.SetZero() + for j := 0; j < len(subVec); j++ { + sum.Add(&sum, &subVec[j]) + } + + assert.True(sum.Equal(&computed), "Vector sum failed") } } diff --git a/ecc/bn254/fp/element_test.go b/ecc/bn254/fp/element_test.go index 7cfc6727d..8c314b633 100644 --- a/ecc/bn254/fp/element_test.go +++ b/ecc/bn254/fp/element_test.go @@ -714,9 +714,22 @@ func TestElementVecOps(t *testing.T) { a := make(Vector, N) b := make(Vector, N) c := make(Vector, N) + m := make(Vector, N) + + // set m to max values element + // it's not really q-1 (since we have 
montgomery representation) + // but it's the "largest" legal value + qMinus1 := new(big.Int).Sub(Modulus(), big.NewInt(1)) + + var eQMinus1 Element + for i, v := range qMinus1.Bits() { + eQMinus1[i] = uint64(v) + } + for i := 0; i < N; i++ { a[i].SetRandom() b[i].SetRandom() + m[i] = eQMinus1 } // Vector addition @@ -726,6 +739,12 @@ func TestElementVecOps(t *testing.T) { expected.Add(&a[i], &b[i]) assert.True(c[i].Equal(&expected), "Vector addition failed") } + c.Add(a, m) + for i := 0; i < N; i++ { + var expected Element + expected.Add(&a[i], &m[i]) + assert.True(c[i].Equal(&expected), "Vector addition failed") + } // Vector subtraction c.Sub(a, b) @@ -734,6 +753,12 @@ func TestElementVecOps(t *testing.T) { expected.Sub(&a[i], &b[i]) assert.True(c[i].Equal(&expected), "Vector subtraction failed") } + c.Sub(a, m) + for i := 0; i < N; i++ { + var expected Element + expected.Sub(&a[i], &m[i]) + assert.True(c[i].Equal(&expected), "Vector subtraction failed") + } // Vector scaling c.ScalarMul(a, &b[0]) @@ -742,6 +767,12 @@ func TestElementVecOps(t *testing.T) { expected.Mul(&a[i], &b[0]) assert.True(c[i].Equal(&expected), "Vector scaling failed") } + c.ScalarMul(m, &b[0]) + for i := 0; i < N; i++ { + var expected Element + expected.Mul(&m[i], &b[0]) + assert.True(c[i].Equal(&expected), "Vector scaling failed") + } // Vector sum for i := 0; i < N/2; i++ { @@ -753,6 +784,15 @@ func TestElementVecOps(t *testing.T) { } assert.True(sum.Equal(&computed), "Vector sum failed") + + subVec = m[:i] + computed = subVec.Sum() + sum.SetZero() + for j := 0; j < len(subVec); j++ { + sum.Add(&sum, &subVec[j]) + } + + assert.True(sum.Equal(&computed), "Vector sum failed") } } diff --git a/ecc/bn254/fr/element_test.go b/ecc/bn254/fr/element_test.go index 7fabe0443..2fb459444 100644 --- a/ecc/bn254/fr/element_test.go +++ b/ecc/bn254/fr/element_test.go @@ -714,9 +714,22 @@ func TestElementVecOps(t *testing.T) { a := make(Vector, N) b := make(Vector, N) c := make(Vector, N) + m := make(Vector, N) + + // set m to max values element + // it's not really q-1 (since we have montgomery representation) + // but it's the "largest" legal value + qMinus1 := new(big.Int).Sub(Modulus(), big.NewInt(1)) + + var eQMinus1 Element + for i, v := range qMinus1.Bits() { + eQMinus1[i] = uint64(v) + } + for i := 0; i < N; i++ { a[i].SetRandom() b[i].SetRandom() + m[i] = eQMinus1 } // Vector addition @@ -726,6 +739,12 @@ func TestElementVecOps(t *testing.T) { expected.Add(&a[i], &b[i]) assert.True(c[i].Equal(&expected), "Vector addition failed") } + c.Add(a, m) + for i := 0; i < N; i++ { + var expected Element + expected.Add(&a[i], &m[i]) + assert.True(c[i].Equal(&expected), "Vector addition failed") + } // Vector subtraction c.Sub(a, b) @@ -734,6 +753,12 @@ func TestElementVecOps(t *testing.T) { expected.Sub(&a[i], &b[i]) assert.True(c[i].Equal(&expected), "Vector subtraction failed") } + c.Sub(a, m) + for i := 0; i < N; i++ { + var expected Element + expected.Sub(&a[i], &m[i]) + assert.True(c[i].Equal(&expected), "Vector subtraction failed") + } // Vector scaling c.ScalarMul(a, &b[0]) @@ -742,6 +767,12 @@ func TestElementVecOps(t *testing.T) { expected.Mul(&a[i], &b[0]) assert.True(c[i].Equal(&expected), "Vector scaling failed") } + c.ScalarMul(m, &b[0]) + for i := 0; i < N; i++ { + var expected Element + expected.Mul(&m[i], &b[0]) + assert.True(c[i].Equal(&expected), "Vector scaling failed") + } // Vector sum for i := 0; i < N/2; i++ { @@ -753,6 +784,15 @@ func TestElementVecOps(t *testing.T) { } assert.True(sum.Equal(&computed), 
"Vector sum failed") + + subVec = m[:i] + computed = subVec.Sum() + sum.SetZero() + for j := 0; j < len(subVec); j++ { + sum.Add(&sum, &subVec[j]) + } + + assert.True(sum.Equal(&computed), "Vector sum failed") } } diff --git a/ecc/bw6-633/fp/element_test.go b/ecc/bw6-633/fp/element_test.go index 08c89c058..35b897fa2 100644 --- a/ecc/bw6-633/fp/element_test.go +++ b/ecc/bw6-633/fp/element_test.go @@ -726,9 +726,22 @@ func TestElementVecOps(t *testing.T) { a := make(Vector, N) b := make(Vector, N) c := make(Vector, N) + m := make(Vector, N) + + // set m to max values element + // it's not really q-1 (since we have montgomery representation) + // but it's the "largest" legal value + qMinus1 := new(big.Int).Sub(Modulus(), big.NewInt(1)) + + var eQMinus1 Element + for i, v := range qMinus1.Bits() { + eQMinus1[i] = uint64(v) + } + for i := 0; i < N; i++ { a[i].SetRandom() b[i].SetRandom() + m[i] = eQMinus1 } // Vector addition @@ -738,6 +751,12 @@ func TestElementVecOps(t *testing.T) { expected.Add(&a[i], &b[i]) assert.True(c[i].Equal(&expected), "Vector addition failed") } + c.Add(a, m) + for i := 0; i < N; i++ { + var expected Element + expected.Add(&a[i], &m[i]) + assert.True(c[i].Equal(&expected), "Vector addition failed") + } // Vector subtraction c.Sub(a, b) @@ -746,6 +765,12 @@ func TestElementVecOps(t *testing.T) { expected.Sub(&a[i], &b[i]) assert.True(c[i].Equal(&expected), "Vector subtraction failed") } + c.Sub(a, m) + for i := 0; i < N; i++ { + var expected Element + expected.Sub(&a[i], &m[i]) + assert.True(c[i].Equal(&expected), "Vector subtraction failed") + } // Vector scaling c.ScalarMul(a, &b[0]) @@ -754,6 +779,12 @@ func TestElementVecOps(t *testing.T) { expected.Mul(&a[i], &b[0]) assert.True(c[i].Equal(&expected), "Vector scaling failed") } + c.ScalarMul(m, &b[0]) + for i := 0; i < N; i++ { + var expected Element + expected.Mul(&m[i], &b[0]) + assert.True(c[i].Equal(&expected), "Vector scaling failed") + } // Vector sum for i := 0; i < N/2; i++ { @@ -765,6 +796,15 @@ func TestElementVecOps(t *testing.T) { } assert.True(sum.Equal(&computed), "Vector sum failed") + + subVec = m[:i] + computed = subVec.Sum() + sum.SetZero() + for j := 0; j < len(subVec); j++ { + sum.Add(&sum, &subVec[j]) + } + + assert.True(sum.Equal(&computed), "Vector sum failed") } } diff --git a/ecc/bw6-633/fr/element_test.go b/ecc/bw6-633/fr/element_test.go index e74315101..19c176ae3 100644 --- a/ecc/bw6-633/fr/element_test.go +++ b/ecc/bw6-633/fr/element_test.go @@ -716,9 +716,22 @@ func TestElementVecOps(t *testing.T) { a := make(Vector, N) b := make(Vector, N) c := make(Vector, N) + m := make(Vector, N) + + // set m to max values element + // it's not really q-1 (since we have montgomery representation) + // but it's the "largest" legal value + qMinus1 := new(big.Int).Sub(Modulus(), big.NewInt(1)) + + var eQMinus1 Element + for i, v := range qMinus1.Bits() { + eQMinus1[i] = uint64(v) + } + for i := 0; i < N; i++ { a[i].SetRandom() b[i].SetRandom() + m[i] = eQMinus1 } // Vector addition @@ -728,6 +741,12 @@ func TestElementVecOps(t *testing.T) { expected.Add(&a[i], &b[i]) assert.True(c[i].Equal(&expected), "Vector addition failed") } + c.Add(a, m) + for i := 0; i < N; i++ { + var expected Element + expected.Add(&a[i], &m[i]) + assert.True(c[i].Equal(&expected), "Vector addition failed") + } // Vector subtraction c.Sub(a, b) @@ -736,6 +755,12 @@ func TestElementVecOps(t *testing.T) { expected.Sub(&a[i], &b[i]) assert.True(c[i].Equal(&expected), "Vector subtraction failed") } + c.Sub(a, m) + for i := 0; i < 
N; i++ { + var expected Element + expected.Sub(&a[i], &m[i]) + assert.True(c[i].Equal(&expected), "Vector subtraction failed") + } // Vector scaling c.ScalarMul(a, &b[0]) @@ -744,6 +769,12 @@ func TestElementVecOps(t *testing.T) { expected.Mul(&a[i], &b[0]) assert.True(c[i].Equal(&expected), "Vector scaling failed") } + c.ScalarMul(m, &b[0]) + for i := 0; i < N; i++ { + var expected Element + expected.Mul(&m[i], &b[0]) + assert.True(c[i].Equal(&expected), "Vector scaling failed") + } // Vector sum for i := 0; i < N/2; i++ { @@ -755,6 +786,15 @@ func TestElementVecOps(t *testing.T) { } assert.True(sum.Equal(&computed), "Vector sum failed") + + subVec = m[:i] + computed = subVec.Sum() + sum.SetZero() + for j := 0; j < len(subVec); j++ { + sum.Add(&sum, &subVec[j]) + } + + assert.True(sum.Equal(&computed), "Vector sum failed") } } diff --git a/ecc/bw6-761/fp/element_test.go b/ecc/bw6-761/fp/element_test.go index 2a8f265cc..07ea72041 100644 --- a/ecc/bw6-761/fp/element_test.go +++ b/ecc/bw6-761/fp/element_test.go @@ -730,9 +730,22 @@ func TestElementVecOps(t *testing.T) { a := make(Vector, N) b := make(Vector, N) c := make(Vector, N) + m := make(Vector, N) + + // set m to max values element + // it's not really q-1 (since we have montgomery representation) + // but it's the "largest" legal value + qMinus1 := new(big.Int).Sub(Modulus(), big.NewInt(1)) + + var eQMinus1 Element + for i, v := range qMinus1.Bits() { + eQMinus1[i] = uint64(v) + } + for i := 0; i < N; i++ { a[i].SetRandom() b[i].SetRandom() + m[i] = eQMinus1 } // Vector addition @@ -742,6 +755,12 @@ func TestElementVecOps(t *testing.T) { expected.Add(&a[i], &b[i]) assert.True(c[i].Equal(&expected), "Vector addition failed") } + c.Add(a, m) + for i := 0; i < N; i++ { + var expected Element + expected.Add(&a[i], &m[i]) + assert.True(c[i].Equal(&expected), "Vector addition failed") + } // Vector subtraction c.Sub(a, b) @@ -750,6 +769,12 @@ func TestElementVecOps(t *testing.T) { expected.Sub(&a[i], &b[i]) assert.True(c[i].Equal(&expected), "Vector subtraction failed") } + c.Sub(a, m) + for i := 0; i < N; i++ { + var expected Element + expected.Sub(&a[i], &m[i]) + assert.True(c[i].Equal(&expected), "Vector subtraction failed") + } // Vector scaling c.ScalarMul(a, &b[0]) @@ -758,6 +783,12 @@ func TestElementVecOps(t *testing.T) { expected.Mul(&a[i], &b[0]) assert.True(c[i].Equal(&expected), "Vector scaling failed") } + c.ScalarMul(m, &b[0]) + for i := 0; i < N; i++ { + var expected Element + expected.Mul(&m[i], &b[0]) + assert.True(c[i].Equal(&expected), "Vector scaling failed") + } // Vector sum for i := 0; i < N/2; i++ { @@ -769,6 +800,15 @@ func TestElementVecOps(t *testing.T) { } assert.True(sum.Equal(&computed), "Vector sum failed") + + subVec = m[:i] + computed = subVec.Sum() + sum.SetZero() + for j := 0; j < len(subVec); j++ { + sum.Add(&sum, &subVec[j]) + } + + assert.True(sum.Equal(&computed), "Vector sum failed") } } diff --git a/ecc/bw6-761/fr/element_test.go b/ecc/bw6-761/fr/element_test.go index 30696a100..35f8f44da 100644 --- a/ecc/bw6-761/fr/element_test.go +++ b/ecc/bw6-761/fr/element_test.go @@ -718,9 +718,22 @@ func TestElementVecOps(t *testing.T) { a := make(Vector, N) b := make(Vector, N) c := make(Vector, N) + m := make(Vector, N) + + // set m to max values element + // it's not really q-1 (since we have montgomery representation) + // but it's the "largest" legal value + qMinus1 := new(big.Int).Sub(Modulus(), big.NewInt(1)) + + var eQMinus1 Element + for i, v := range qMinus1.Bits() { + eQMinus1[i] = uint64(v) + } + 
for i := 0; i < N; i++ { a[i].SetRandom() b[i].SetRandom() + m[i] = eQMinus1 } // Vector addition @@ -730,6 +743,12 @@ func TestElementVecOps(t *testing.T) { expected.Add(&a[i], &b[i]) assert.True(c[i].Equal(&expected), "Vector addition failed") } + c.Add(a, m) + for i := 0; i < N; i++ { + var expected Element + expected.Add(&a[i], &m[i]) + assert.True(c[i].Equal(&expected), "Vector addition failed") + } // Vector subtraction c.Sub(a, b) @@ -738,6 +757,12 @@ func TestElementVecOps(t *testing.T) { expected.Sub(&a[i], &b[i]) assert.True(c[i].Equal(&expected), "Vector subtraction failed") } + c.Sub(a, m) + for i := 0; i < N; i++ { + var expected Element + expected.Sub(&a[i], &m[i]) + assert.True(c[i].Equal(&expected), "Vector subtraction failed") + } // Vector scaling c.ScalarMul(a, &b[0]) @@ -746,6 +771,12 @@ func TestElementVecOps(t *testing.T) { expected.Mul(&a[i], &b[0]) assert.True(c[i].Equal(&expected), "Vector scaling failed") } + c.ScalarMul(m, &b[0]) + for i := 0; i < N; i++ { + var expected Element + expected.Mul(&m[i], &b[0]) + assert.True(c[i].Equal(&expected), "Vector scaling failed") + } // Vector sum for i := 0; i < N/2; i++ { @@ -757,6 +788,15 @@ func TestElementVecOps(t *testing.T) { } assert.True(sum.Equal(&computed), "Vector sum failed") + + subVec = m[:i] + computed = subVec.Sum() + sum.SetZero() + for j := 0; j < len(subVec); j++ { + sum.Add(&sum, &subVec[j]) + } + + assert.True(sum.Equal(&computed), "Vector sum failed") } } diff --git a/ecc/secp256k1/fp/element_test.go b/ecc/secp256k1/fp/element_test.go index 831cd3efe..a8672f7e1 100644 --- a/ecc/secp256k1/fp/element_test.go +++ b/ecc/secp256k1/fp/element_test.go @@ -712,9 +712,22 @@ func TestElementVecOps(t *testing.T) { a := make(Vector, N) b := make(Vector, N) c := make(Vector, N) + m := make(Vector, N) + + // set m to max values element + // it's not really q-1 (since we have montgomery representation) + // but it's the "largest" legal value + qMinus1 := new(big.Int).Sub(Modulus(), big.NewInt(1)) + + var eQMinus1 Element + for i, v := range qMinus1.Bits() { + eQMinus1[i] = uint64(v) + } + for i := 0; i < N; i++ { a[i].SetRandom() b[i].SetRandom() + m[i] = eQMinus1 } // Vector addition @@ -724,6 +737,12 @@ func TestElementVecOps(t *testing.T) { expected.Add(&a[i], &b[i]) assert.True(c[i].Equal(&expected), "Vector addition failed") } + c.Add(a, m) + for i := 0; i < N; i++ { + var expected Element + expected.Add(&a[i], &m[i]) + assert.True(c[i].Equal(&expected), "Vector addition failed") + } // Vector subtraction c.Sub(a, b) @@ -732,6 +751,12 @@ func TestElementVecOps(t *testing.T) { expected.Sub(&a[i], &b[i]) assert.True(c[i].Equal(&expected), "Vector subtraction failed") } + c.Sub(a, m) + for i := 0; i < N; i++ { + var expected Element + expected.Sub(&a[i], &m[i]) + assert.True(c[i].Equal(&expected), "Vector subtraction failed") + } // Vector scaling c.ScalarMul(a, &b[0]) @@ -740,6 +765,12 @@ func TestElementVecOps(t *testing.T) { expected.Mul(&a[i], &b[0]) assert.True(c[i].Equal(&expected), "Vector scaling failed") } + c.ScalarMul(m, &b[0]) + for i := 0; i < N; i++ { + var expected Element + expected.Mul(&m[i], &b[0]) + assert.True(c[i].Equal(&expected), "Vector scaling failed") + } // Vector sum for i := 0; i < N/2; i++ { @@ -751,6 +782,15 @@ func TestElementVecOps(t *testing.T) { } assert.True(sum.Equal(&computed), "Vector sum failed") + + subVec = m[:i] + computed = subVec.Sum() + sum.SetZero() + for j := 0; j < len(subVec); j++ { + sum.Add(&sum, &subVec[j]) + } + + assert.True(sum.Equal(&computed), "Vector sum 
failed") } } diff --git a/ecc/secp256k1/fr/element_test.go b/ecc/secp256k1/fr/element_test.go index d47e068b0..c3f0d3ac2 100644 --- a/ecc/secp256k1/fr/element_test.go +++ b/ecc/secp256k1/fr/element_test.go @@ -712,9 +712,22 @@ func TestElementVecOps(t *testing.T) { a := make(Vector, N) b := make(Vector, N) c := make(Vector, N) + m := make(Vector, N) + + // set m to max values element + // it's not really q-1 (since we have montgomery representation) + // but it's the "largest" legal value + qMinus1 := new(big.Int).Sub(Modulus(), big.NewInt(1)) + + var eQMinus1 Element + for i, v := range qMinus1.Bits() { + eQMinus1[i] = uint64(v) + } + for i := 0; i < N; i++ { a[i].SetRandom() b[i].SetRandom() + m[i] = eQMinus1 } // Vector addition @@ -724,6 +737,12 @@ func TestElementVecOps(t *testing.T) { expected.Add(&a[i], &b[i]) assert.True(c[i].Equal(&expected), "Vector addition failed") } + c.Add(a, m) + for i := 0; i < N; i++ { + var expected Element + expected.Add(&a[i], &m[i]) + assert.True(c[i].Equal(&expected), "Vector addition failed") + } // Vector subtraction c.Sub(a, b) @@ -732,6 +751,12 @@ func TestElementVecOps(t *testing.T) { expected.Sub(&a[i], &b[i]) assert.True(c[i].Equal(&expected), "Vector subtraction failed") } + c.Sub(a, m) + for i := 0; i < N; i++ { + var expected Element + expected.Sub(&a[i], &m[i]) + assert.True(c[i].Equal(&expected), "Vector subtraction failed") + } // Vector scaling c.ScalarMul(a, &b[0]) @@ -740,6 +765,12 @@ func TestElementVecOps(t *testing.T) { expected.Mul(&a[i], &b[0]) assert.True(c[i].Equal(&expected), "Vector scaling failed") } + c.ScalarMul(m, &b[0]) + for i := 0; i < N; i++ { + var expected Element + expected.Mul(&m[i], &b[0]) + assert.True(c[i].Equal(&expected), "Vector scaling failed") + } // Vector sum for i := 0; i < N/2; i++ { @@ -751,6 +782,15 @@ func TestElementVecOps(t *testing.T) { } assert.True(sum.Equal(&computed), "Vector sum failed") + + subVec = m[:i] + computed = subVec.Sum() + sum.SetZero() + for j := 0; j < len(subVec); j++ { + sum.Add(&sum, &subVec[j]) + } + + assert.True(sum.Equal(&computed), "Vector sum failed") } } diff --git a/ecc/stark-curve/fp/element_test.go b/ecc/stark-curve/fp/element_test.go index 8efeea221..5541a687d 100644 --- a/ecc/stark-curve/fp/element_test.go +++ b/ecc/stark-curve/fp/element_test.go @@ -714,9 +714,22 @@ func TestElementVecOps(t *testing.T) { a := make(Vector, N) b := make(Vector, N) c := make(Vector, N) + m := make(Vector, N) + + // set m to max values element + // it's not really q-1 (since we have montgomery representation) + // but it's the "largest" legal value + qMinus1 := new(big.Int).Sub(Modulus(), big.NewInt(1)) + + var eQMinus1 Element + for i, v := range qMinus1.Bits() { + eQMinus1[i] = uint64(v) + } + for i := 0; i < N; i++ { a[i].SetRandom() b[i].SetRandom() + m[i] = eQMinus1 } // Vector addition @@ -726,6 +739,12 @@ func TestElementVecOps(t *testing.T) { expected.Add(&a[i], &b[i]) assert.True(c[i].Equal(&expected), "Vector addition failed") } + c.Add(a, m) + for i := 0; i < N; i++ { + var expected Element + expected.Add(&a[i], &m[i]) + assert.True(c[i].Equal(&expected), "Vector addition failed") + } // Vector subtraction c.Sub(a, b) @@ -734,6 +753,12 @@ func TestElementVecOps(t *testing.T) { expected.Sub(&a[i], &b[i]) assert.True(c[i].Equal(&expected), "Vector subtraction failed") } + c.Sub(a, m) + for i := 0; i < N; i++ { + var expected Element + expected.Sub(&a[i], &m[i]) + assert.True(c[i].Equal(&expected), "Vector subtraction failed") + } // Vector scaling c.ScalarMul(a, &b[0]) @@ 
-742,6 +767,12 @@ func TestElementVecOps(t *testing.T) { expected.Mul(&a[i], &b[0]) assert.True(c[i].Equal(&expected), "Vector scaling failed") } + c.ScalarMul(m, &b[0]) + for i := 0; i < N; i++ { + var expected Element + expected.Mul(&m[i], &b[0]) + assert.True(c[i].Equal(&expected), "Vector scaling failed") + } // Vector sum for i := 0; i < N/2; i++ { @@ -753,6 +784,15 @@ func TestElementVecOps(t *testing.T) { } assert.True(sum.Equal(&computed), "Vector sum failed") + + subVec = m[:i] + computed = subVec.Sum() + sum.SetZero() + for j := 0; j < len(subVec); j++ { + sum.Add(&sum, &subVec[j]) + } + + assert.True(sum.Equal(&computed), "Vector sum failed") } } diff --git a/ecc/stark-curve/fr/element_test.go b/ecc/stark-curve/fr/element_test.go index ceae12aac..88dc47337 100644 --- a/ecc/stark-curve/fr/element_test.go +++ b/ecc/stark-curve/fr/element_test.go @@ -714,9 +714,22 @@ func TestElementVecOps(t *testing.T) { a := make(Vector, N) b := make(Vector, N) c := make(Vector, N) + m := make(Vector, N) + + // set m to max values element + // it's not really q-1 (since we have montgomery representation) + // but it's the "largest" legal value + qMinus1 := new(big.Int).Sub(Modulus(), big.NewInt(1)) + + var eQMinus1 Element + for i, v := range qMinus1.Bits() { + eQMinus1[i] = uint64(v) + } + for i := 0; i < N; i++ { a[i].SetRandom() b[i].SetRandom() + m[i] = eQMinus1 } // Vector addition @@ -726,6 +739,12 @@ func TestElementVecOps(t *testing.T) { expected.Add(&a[i], &b[i]) assert.True(c[i].Equal(&expected), "Vector addition failed") } + c.Add(a, m) + for i := 0; i < N; i++ { + var expected Element + expected.Add(&a[i], &m[i]) + assert.True(c[i].Equal(&expected), "Vector addition failed") + } // Vector subtraction c.Sub(a, b) @@ -734,6 +753,12 @@ func TestElementVecOps(t *testing.T) { expected.Sub(&a[i], &b[i]) assert.True(c[i].Equal(&expected), "Vector subtraction failed") } + c.Sub(a, m) + for i := 0; i < N; i++ { + var expected Element + expected.Sub(&a[i], &m[i]) + assert.True(c[i].Equal(&expected), "Vector subtraction failed") + } // Vector scaling c.ScalarMul(a, &b[0]) @@ -742,6 +767,12 @@ func TestElementVecOps(t *testing.T) { expected.Mul(&a[i], &b[0]) assert.True(c[i].Equal(&expected), "Vector scaling failed") } + c.ScalarMul(m, &b[0]) + for i := 0; i < N; i++ { + var expected Element + expected.Mul(&m[i], &b[0]) + assert.True(c[i].Equal(&expected), "Vector scaling failed") + } // Vector sum for i := 0; i < N/2; i++ { @@ -753,6 +784,15 @@ func TestElementVecOps(t *testing.T) { } assert.True(sum.Equal(&computed), "Vector sum failed") + + subVec = m[:i] + computed = subVec.Sum() + sum.SetZero() + for j := 0; j < len(subVec); j++ { + sum.Add(&sum, &subVec[j]) + } + + assert.True(sum.Equal(&computed), "Vector sum failed") } } diff --git a/field/generator/internal/templates/element/tests.go b/field/generator/internal/templates/element/tests.go index 603b4a8f7..f31528489 100644 --- a/field/generator/internal/templates/element/tests.go +++ b/field/generator/internal/templates/element/tests.go @@ -735,9 +735,22 @@ func Test{{toTitle .ElementName}}VecOps(t *testing.T) { a := make(Vector, N) b := make(Vector, N) c := make(Vector, N) + m := make(Vector, N) + + // set m to max values element + // it's not really q-1 (since we have montgomery representation) + // but it's the "largest" legal value + qMinus1 := new(big.Int).Sub(Modulus(), big.NewInt(1)) + + var eQMinus1 {{.ElementName}} + for i, v := range qMinus1.Bits() { + eQMinus1[i] = uint64(v) + } + for i := 0; i < N; i++ { a[i].SetRandom() 
b[i].SetRandom() + m[i] = eQMinus1 } // Vector addition @@ -747,6 +760,12 @@ func Test{{toTitle .ElementName}}VecOps(t *testing.T) { expected.Add(&a[i], &b[i]) assert.True(c[i].Equal(&expected), "Vector addition failed") } + c.Add(a, m) + for i := 0; i < N; i++ { + var expected {{.ElementName}} + expected.Add(&a[i], &m[i]) + assert.True(c[i].Equal(&expected), "Vector addition failed") + } // Vector subtraction c.Sub(a, b) @@ -755,6 +774,12 @@ func Test{{toTitle .ElementName}}VecOps(t *testing.T) { expected.Sub(&a[i], &b[i]) assert.True(c[i].Equal(&expected), "Vector subtraction failed") } + c.Sub(a, m) + for i := 0; i < N; i++ { + var expected {{.ElementName}} + expected.Sub(&a[i], &m[i]) + assert.True(c[i].Equal(&expected), "Vector subtraction failed") + } // Vector scaling c.ScalarMul(a, &b[0]) @@ -763,6 +788,12 @@ func Test{{toTitle .ElementName}}VecOps(t *testing.T) { expected.Mul(&a[i], &b[0]) assert.True(c[i].Equal(&expected), "Vector scaling failed") } + c.ScalarMul(m, &b[0]) + for i := 0; i < N; i++ { + var expected {{.ElementName}} + expected.Mul(&m[i], &b[0]) + assert.True(c[i].Equal(&expected), "Vector scaling failed") + } // Vector sum for i := 0; i < N/2; i++ { @@ -774,6 +805,15 @@ func Test{{toTitle .ElementName}}VecOps(t *testing.T) { } assert.True(sum.Equal(&computed), "Vector sum failed") + + subVec = m[:i] + computed = subVec.Sum() + sum.SetZero() + for j := 0; j < len(subVec); j++ { + sum.Add(&sum, &subVec[j]) + } + + assert.True(sum.Equal(&computed), "Vector sum failed") } } diff --git a/field/goldilocks/element_test.go b/field/goldilocks/element_test.go index 20b183ded..1746bde1c 100644 --- a/field/goldilocks/element_test.go +++ b/field/goldilocks/element_test.go @@ -659,9 +659,22 @@ func TestElementVecOps(t *testing.T) { a := make(Vector, N) b := make(Vector, N) c := make(Vector, N) + m := make(Vector, N) + + // set m to max values element + // it's not really q-1 (since we have montgomery representation) + // but it's the "largest" legal value + qMinus1 := new(big.Int).Sub(Modulus(), big.NewInt(1)) + + var eQMinus1 Element + for i, v := range qMinus1.Bits() { + eQMinus1[i] = uint64(v) + } + for i := 0; i < N; i++ { a[i].SetRandom() b[i].SetRandom() + m[i] = eQMinus1 } // Vector addition @@ -671,6 +684,12 @@ func TestElementVecOps(t *testing.T) { expected.Add(&a[i], &b[i]) assert.True(c[i].Equal(&expected), "Vector addition failed") } + c.Add(a, m) + for i := 0; i < N; i++ { + var expected Element + expected.Add(&a[i], &m[i]) + assert.True(c[i].Equal(&expected), "Vector addition failed") + } // Vector subtraction c.Sub(a, b) @@ -679,6 +698,12 @@ func TestElementVecOps(t *testing.T) { expected.Sub(&a[i], &b[i]) assert.True(c[i].Equal(&expected), "Vector subtraction failed") } + c.Sub(a, m) + for i := 0; i < N; i++ { + var expected Element + expected.Sub(&a[i], &m[i]) + assert.True(c[i].Equal(&expected), "Vector subtraction failed") + } // Vector scaling c.ScalarMul(a, &b[0]) @@ -687,6 +712,12 @@ func TestElementVecOps(t *testing.T) { expected.Mul(&a[i], &b[0]) assert.True(c[i].Equal(&expected), "Vector scaling failed") } + c.ScalarMul(m, &b[0]) + for i := 0; i < N; i++ { + var expected Element + expected.Mul(&m[i], &b[0]) + assert.True(c[i].Equal(&expected), "Vector scaling failed") + } // Vector sum for i := 0; i < N/2; i++ { @@ -698,6 +729,15 @@ func TestElementVecOps(t *testing.T) { } assert.True(sum.Equal(&computed), "Vector sum failed") + + subVec = m[:i] + computed = subVec.Sum() + sum.SetZero() + for j := 0; j < len(subVec); j++ { + sum.Add(&sum, &subVec[j]) + } 
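// Illustrative sketch, not part of the patch: m above is filled with the limbs of q-1.
// Because elements are stored in Montgomery form, that limb pattern encodes the field
// value (q-1)*R^-1 mod q rather than q-1 itself; what the test relies on is only that it
// is the largest limb pattern still below q, so the Add/Sub/ScalarMul/Sum calls above all
// exercise their final conditional-subtraction / borrow paths. Assuming 64-bit limbs and
// the math/big package already used a few lines above, a sanity check of that property is:
//
//	var b big.Int
//	for i := len(eQMinus1) - 1; i >= 0; i-- { // limbs are little-endian 64-bit words
//		b.Lsh(&b, 64)
//		b.Add(&b, new(big.Int).SetUint64(eQMinus1[i]))
//	}
//	if b.Cmp(new(big.Int).Sub(Modulus(), big.NewInt(1))) != 0 {
//		t.Fatal("m is not the largest reduced limb pattern")
//	}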
+ + assert.True(sum.Equal(&computed), "Vector sum failed") } } From fcfaa059798516f6b91926c930c5fe60112f7324 Mon Sep 17 00:00:00 2001 From: Gautam Botrel Date: Mon, 23 Sep 2024 21:24:27 -0500 Subject: [PATCH 13/14] refactor: move common assembly routine in subfolder (#545) * refactor: move common assembly routines in root * build: make linter happier * style: cosmetics * test: start fixing integration test * style: factorize mul documentation * feat: add .ASMVector and fix integartion test * test: fix 32bit test * test: fix previous commit --- ecc/bls12-377/fp/element.go | 28 +- ecc/bls12-377/fp/element_ops_amd64.go | 44 +- ecc/bls12-377/fp/element_ops_amd64.s | 296 +--- ecc/bls12-377/fp/element_ops_purego.go | 44 +- ecc/bls12-377/fp/element_test.go | 17 +- ecc/bls12-377/fr/element.go | 28 +- ecc/bls12-377/fr/element_mul_amd64.s | 490 ------- ecc/bls12-377/fr/element_ops_amd64.go | 44 +- ecc/bls12-377/fr/element_ops_amd64.s | 826 +---------- ecc/bls12-377/fr/element_ops_purego.go | 44 +- ecc/bls12-377/fr/element_test.go | 17 +- ecc/bls12-381/fp/element.go | 28 +- ecc/bls12-381/fp/element_mul_amd64.s | 857 ----------- ecc/bls12-381/fp/element_ops_amd64.go | 44 +- ecc/bls12-381/fp/element_ops_amd64.s | 296 +--- ecc/bls12-381/fp/element_ops_purego.go | 44 +- ecc/bls12-381/fp/element_test.go | 17 +- ecc/bls12-381/fr/element.go | 28 +- ecc/bls12-381/fr/element_mul_amd64.s | 490 ------- ecc/bls12-381/fr/element_ops_amd64.go | 44 +- ecc/bls12-381/fr/element_ops_amd64.s | 826 +---------- ecc/bls12-381/fr/element_ops_purego.go | 44 +- ecc/bls12-381/fr/element_test.go | 17 +- ecc/bls24-315/fp/element.go | 28 +- ecc/bls24-315/fp/element_mul_amd64.s | 656 --------- ecc/bls24-315/fp/element_ops_amd64.go | 44 +- ecc/bls24-315/fp/element_ops_amd64.s | 262 +--- ecc/bls24-315/fp/element_ops_purego.go | 44 +- ecc/bls24-315/fp/element_test.go | 17 +- ecc/bls24-315/fr/element.go | 28 +- ecc/bls24-315/fr/element_mul_amd64.s | 490 ------- ecc/bls24-315/fr/element_ops_amd64.go | 44 +- ecc/bls24-315/fr/element_ops_amd64.s | 826 +---------- ecc/bls24-315/fr/element_ops_purego.go | 44 +- ecc/bls24-315/fr/element_test.go | 17 +- ecc/bls24-317/fp/element.go | 28 +- ecc/bls24-317/fp/element_mul_amd64.s | 656 --------- ecc/bls24-317/fp/element_ops_amd64.go | 44 +- ecc/bls24-317/fp/element_ops_amd64.s | 262 +--- ecc/bls24-317/fp/element_ops_purego.go | 44 +- ecc/bls24-317/fp/element_test.go | 17 +- ecc/bls24-317/fr/element.go | 28 +- ecc/bls24-317/fr/element_mul_amd64.s | 490 ------- ecc/bls24-317/fr/element_ops_amd64.go | 44 +- ecc/bls24-317/fr/element_ops_amd64.s | 826 +---------- ecc/bls24-317/fr/element_ops_purego.go | 44 +- ecc/bls24-317/fr/element_test.go | 17 +- ecc/bn254/fp/element.go | 28 +- ecc/bn254/fp/element_mul_amd64.s | 490 ------- ecc/bn254/fp/element_ops_amd64.go | 44 +- ecc/bn254/fp/element_ops_amd64.s | 826 +---------- ecc/bn254/fp/element_ops_purego.go | 44 +- ecc/bn254/fp/element_test.go | 17 +- ecc/bn254/fr/element.go | 28 +- ecc/bn254/fr/element_mul_amd64.s | 490 ------- ecc/bn254/fr/element_ops_amd64.go | 44 +- ecc/bn254/fr/element_ops_amd64.s | 826 +---------- ecc/bn254/fr/element_ops_purego.go | 44 +- ecc/bn254/fr/element_test.go | 17 +- ecc/bw6-633/fp/element.go | 28 +- ecc/bw6-633/fp/element_ops_amd64.go | 44 +- ecc/bw6-633/fp/element_ops_amd64.s | 428 +----- ecc/bw6-633/fp/element_ops_purego.go | 44 +- ecc/bw6-633/fp/element_test.go | 17 +- ecc/bw6-633/fr/element.go | 28 +- ecc/bw6-633/fr/element_ops_amd64.go | 44 +- ecc/bw6-633/fr/element_ops_amd64.s | 262 +--- 
ecc/bw6-633/fr/element_ops_purego.go | 44 +- ecc/bw6-633/fr/element_test.go | 17 +- ecc/bw6-761/fp/element.go | 28 +- ecc/bw6-761/fp/element_ops_amd64.go | 44 +- ecc/bw6-761/fp/element_ops_amd64.s | 494 +------ ecc/bw6-761/fp/element_ops_purego.go | 44 +- ecc/bw6-761/fp/element_test.go | 17 +- ecc/bw6-761/fr/element.go | 28 +- ecc/bw6-761/fr/element_mul_amd64.s | 857 ----------- ecc/bw6-761/fr/element_ops_amd64.go | 44 +- ecc/bw6-761/fr/element_ops_amd64.s | 296 +--- ecc/bw6-761/fr/element_ops_purego.go | 44 +- ecc/bw6-761/fr/element_test.go | 17 +- ecc/secp256k1/fp/element.go | 28 +- ecc/secp256k1/fp/element_ops_purego.go | 52 +- ecc/secp256k1/fp/element_test.go | 17 +- ecc/secp256k1/fp/vector.go | 24 + ecc/secp256k1/fr/element.go | 28 +- ecc/secp256k1/fr/element_ops_purego.go | 52 +- ecc/secp256k1/fr/element_test.go | 17 +- ecc/secp256k1/fr/vector.go | 24 + ecc/stark-curve/fp/element.go | 28 +- ecc/stark-curve/fp/element_mul_amd64.s | 490 ------- ecc/stark-curve/fp/element_ops_amd64.go | 44 +- ecc/stark-curve/fp/element_ops_amd64.s | 826 +---------- ecc/stark-curve/fp/element_ops_purego.go | 44 +- ecc/stark-curve/fp/element_test.go | 17 +- ecc/stark-curve/fr/element.go | 28 +- ecc/stark-curve/fr/element_mul_amd64.s | 490 ------- ecc/stark-curve/fr/element_ops_amd64.go | 44 +- ecc/stark-curve/fr/element_ops_amd64.s | 826 +---------- ecc/stark-curve/fr/element_ops_purego.go | 44 +- ecc/stark-curve/fr/element_test.go | 17 +- field/asm/.gitignore | 5 + .../asm/element_10w_amd64.h | 413 +++++- .../asm/element_12w_amd64.h | 473 ++++++- field/asm/element_4w_amd64.h | 1258 +++++++++++++++++ .../asm/element_5w_amd64.h | 256 +++- .../asm/element_6w_amd64.h | 287 +++- field/generator/asm/amd64/asm_macros.go | 70 +- field/generator/asm/amd64/build.go | 90 +- .../generator/asm/amd64/element_butterfly.go | 15 +- field/generator/asm/amd64/element_frommont.go | 4 +- field/generator/asm/amd64/element_mul.go | 17 +- field/generator/asm/amd64/element_vec.go | 4 +- field/generator/config/field_config.go | 23 +- field/generator/generator.go | 57 +- field/generator/generator_test.go | 63 +- .../internal/templates/element/asm.go | 4 +- .../internal/templates/element/mul_cios.go | 49 +- .../internal/templates/element/mul_nocarry.go | 4 +- .../internal/templates/element/ops_asm.go | 2 +- .../internal/templates/element/ops_purego.go | 2 +- .../internal/templates/element/tests.go | 18 +- .../internal/templates/element/vector.go | 2 +- field/goff/cmd/root.go | 3 +- field/goldilocks/element.go | 28 +- field/goldilocks/element_test.go | 17 +- field/goldilocks/internal/main.go | 2 +- go.mod | 1 + go.sum | 1 + internal/generator/main.go | 33 +- internal/generator/tower/asm/amd64/e2.go | 10 +- .../generator/tower/asm/amd64/e2_bn254.go | 2 +- 131 files changed, 3314 insertions(+), 18419 deletions(-) delete mode 100644 ecc/bls12-377/fr/element_mul_amd64.s delete mode 100644 ecc/bls12-381/fp/element_mul_amd64.s delete mode 100644 ecc/bls12-381/fr/element_mul_amd64.s delete mode 100644 ecc/bls24-315/fp/element_mul_amd64.s delete mode 100644 ecc/bls24-315/fr/element_mul_amd64.s delete mode 100644 ecc/bls24-317/fp/element_mul_amd64.s delete mode 100644 ecc/bls24-317/fr/element_mul_amd64.s delete mode 100644 ecc/bn254/fp/element_mul_amd64.s delete mode 100644 ecc/bn254/fr/element_mul_amd64.s delete mode 100644 ecc/bw6-761/fr/element_mul_amd64.s delete mode 100644 ecc/stark-curve/fp/element_mul_amd64.s delete mode 100644 ecc/stark-curve/fr/element_mul_amd64.s create mode 100644 field/asm/.gitignore rename 
ecc/bw6-633/fp/element_mul_amd64.s => field/asm/element_10w_amd64.h (79%) rename ecc/bw6-761/fp/element_mul_amd64.s => field/asm/element_12w_amd64.h (81%) create mode 100644 field/asm/element_4w_amd64.h rename ecc/bw6-633/fr/element_mul_amd64.s => field/asm/element_5w_amd64.h (70%) rename ecc/bls12-377/fp/element_mul_amd64.s => field/asm/element_6w_amd64.h (72%) diff --git a/ecc/bls12-377/fp/element.go b/ecc/bls12-377/fp/element.go index 81a730fbd..393f45744 100644 --- a/ecc/bls12-377/fp/element.go +++ b/ecc/bls12-377/fp/element.go @@ -521,32 +521,8 @@ func (z *Element) Select(c int, x0 *Element, x1 *Element) *Element { // and is used for testing purposes. func _mulGeneric(z, x, y *Element) { - // Implements CIOS multiplication -- section 2.3.2 of Tolga Acar's thesis - // https://www.microsoft.com/en-us/research/wp-content/uploads/1998/06/97Acar.pdf - // - // The algorithm: - // - // for i=0 to N-1 - // C := 0 - // for j=0 to N-1 - // (C,t[j]) := t[j] + x[j]*y[i] + C - // (t[N+1],t[N]) := t[N] + C - // - // C := 0 - // m := t[0]*q'[0] mod D - // (C,_) := t[0] + m*q[0] - // for j=1 to N-1 - // (C,t[j-1]) := t[j] + m*q[j] + C - // - // (C,t[N-1]) := t[N] + C - // t[N] := t[N+1] + C - // - // → N is the number of machine words needed to store the modulus q - // → D is the word size. For example, on a 64-bit architecture D is 2 64 - // → x[i], y[i], q[i] is the ith word of the numbers x,y,q - // → q'[0] is the lowest word of the number -q⁻¹ mod r. This quantity is pre-computed, as it does not depend on the inputs. - // → t is a temporary array of size N+2 - // → C, S are machine words. A pair (C,S) refers to (hi-bits, lo-bits) of a two-word number + // Algorithm 2 of "Faster Montgomery Multiplication and Multi-Scalar-Multiplication for SNARKS" + // by Y. El Housni and G. Botrel https://doi.org/10.46586/tches.v2023.i3.504-521 var t [7]uint64 var D uint64 diff --git a/ecc/bls12-377/fp/element_ops_amd64.go b/ecc/bls12-377/fp/element_ops_amd64.go index 83bba45ae..ed2803d71 100644 --- a/ecc/bls12-377/fp/element_ops_amd64.go +++ b/ecc/bls12-377/fp/element_ops_amd64.go @@ -50,48 +50,8 @@ func Butterfly(a, b *Element) // x and y must be less than q func (z *Element) Mul(x, y *Element) *Element { - // Implements CIOS multiplication -- section 2.3.2 of Tolga Acar's thesis - // https://www.microsoft.com/en-us/research/wp-content/uploads/1998/06/97Acar.pdf - // - // The algorithm: - // - // for i=0 to N-1 - // C := 0 - // for j=0 to N-1 - // (C,t[j]) := t[j] + x[j]*y[i] + C - // (t[N+1],t[N]) := t[N] + C - // - // C := 0 - // m := t[0]*q'[0] mod D - // (C,_) := t[0] + m*q[0] - // for j=1 to N-1 - // (C,t[j-1]) := t[j] + m*q[j] + C - // - // (C,t[N-1]) := t[N] + C - // t[N] := t[N+1] + C - // - // → N is the number of machine words needed to store the modulus q - // → D is the word size. For example, on a 64-bit architecture D is 2 64 - // → x[i], y[i], q[i] is the ith word of the numbers x,y,q - // → q'[0] is the lowest word of the number -q⁻¹ mod r. This quantity is pre-computed, as it does not depend on the inputs. - // → t is a temporary array of size N+2 - // → C, S are machine words. 
A pair (C,S) refers to (hi-bits, lo-bits) of a two-word number - // - // As described here https://hackmd.io/@gnark/modular_multiplication we can get rid of one carry chain and simplify: - // (also described in https://eprint.iacr.org/2022/1400.pdf annex) - // - // for i=0 to N-1 - // (A,t[0]) := t[0] + x[0]*y[i] - // m := t[0]*q'[0] mod W - // C,_ := t[0] + m*q[0] - // for j=1 to N-1 - // (A,t[j]) := t[j] + x[j]*y[i] + A - // (C,t[j-1]) := t[j] + m*q[j] + C - // - // t[N-1] = C + A - // - // This optimization saves 5N + 2 additions in the algorithm, and can be used whenever the highest bit - // of the modulus is zero (and not all of the remaining bits are set). + // Algorithm 2 of "Faster Montgomery Multiplication and Multi-Scalar-Multiplication for SNARKS" + // by Y. El Housni and G. Botrel https://doi.org/10.46586/tches.v2023.i3.504-521 mul(z, x, y) return z diff --git a/ecc/bls12-377/fp/element_ops_amd64.s b/ecc/bls12-377/fp/element_ops_amd64.s index 7242622a4..cd10b7d6d 100644 --- a/ecc/bls12-377/fp/element_ops_amd64.s +++ b/ecc/bls12-377/fp/element_ops_amd64.s @@ -1,21 +1,13 @@ // +build !purego -// Copyright 2020 ConsenSys Software Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. +#define q0 $0x8508c00000000001 +#define q1 $0x170b5d4430000000 +#define q2 $0x1ef3622fba094800 +#define q3 $0x1a22d9f300f5138f +#define q4 $0xc63b05c06ca1493b +#define q5 $0x01ae3a4617c510ea -#include "textflag.h" -#include "funcdata.h" +#include "../../../field/asm/element_6w_amd64.h" // modulus q DATA q<>+0(SB)/8, $0x8508c00000000001 @@ -30,277 +22,3 @@ GLOBL q<>(SB), (RODATA+NOPTR), $48 DATA qInv0<>(SB)/8, $0x8508bfffffffffff GLOBL qInv0<>(SB), (RODATA+NOPTR), $8 -#define REDUCE(ra0, ra1, ra2, ra3, ra4, ra5, rb0, rb1, rb2, rb3, rb4, rb5) \ - MOVQ ra0, rb0; \ - SUBQ q<>(SB), ra0; \ - MOVQ ra1, rb1; \ - SBBQ q<>+8(SB), ra1; \ - MOVQ ra2, rb2; \ - SBBQ q<>+16(SB), ra2; \ - MOVQ ra3, rb3; \ - SBBQ q<>+24(SB), ra3; \ - MOVQ ra4, rb4; \ - SBBQ q<>+32(SB), ra4; \ - MOVQ ra5, rb5; \ - SBBQ q<>+40(SB), ra5; \ - CMOVQCS rb0, ra0; \ - CMOVQCS rb1, ra1; \ - CMOVQCS rb2, ra2; \ - CMOVQCS rb3, ra3; \ - CMOVQCS rb4, ra4; \ - CMOVQCS rb5, ra5; \ - -TEXT ·reduce(SB), NOSPLIT, $0-8 - MOVQ res+0(FP), AX - MOVQ 0(AX), DX - MOVQ 8(AX), CX - MOVQ 16(AX), BX - MOVQ 24(AX), SI - MOVQ 32(AX), DI - MOVQ 40(AX), R8 - - // reduce element(DX,CX,BX,SI,DI,R8) using temp registers (R9,R10,R11,R12,R13,R14) - REDUCE(DX,CX,BX,SI,DI,R8,R9,R10,R11,R12,R13,R14) - - MOVQ DX, 0(AX) - MOVQ CX, 8(AX) - MOVQ BX, 16(AX) - MOVQ SI, 24(AX) - MOVQ DI, 32(AX) - MOVQ R8, 40(AX) - RET - -// MulBy3(x *Element) -TEXT ·MulBy3(SB), NOSPLIT, $0-8 - MOVQ x+0(FP), AX - MOVQ 0(AX), DX - MOVQ 8(AX), CX - MOVQ 16(AX), BX - MOVQ 24(AX), SI - MOVQ 32(AX), DI - MOVQ 40(AX), R8 - ADDQ DX, DX - ADCQ CX, CX - ADCQ BX, BX - ADCQ SI, SI - ADCQ DI, DI - ADCQ R8, R8 - - // reduce element(DX,CX,BX,SI,DI,R8) using temp registers (R9,R10,R11,R12,R13,R14) - REDUCE(DX,CX,BX,SI,DI,R8,R9,R10,R11,R12,R13,R14) - - ADDQ 0(AX), DX - ADCQ 8(AX), CX - ADCQ 16(AX), BX - 
ADCQ 24(AX), SI - ADCQ 32(AX), DI - ADCQ 40(AX), R8 - - // reduce element(DX,CX,BX,SI,DI,R8) using temp registers (R15,R9,R10,R11,R12,R13) - REDUCE(DX,CX,BX,SI,DI,R8,R15,R9,R10,R11,R12,R13) - - MOVQ DX, 0(AX) - MOVQ CX, 8(AX) - MOVQ BX, 16(AX) - MOVQ SI, 24(AX) - MOVQ DI, 32(AX) - MOVQ R8, 40(AX) - RET - -// MulBy5(x *Element) -TEXT ·MulBy5(SB), NOSPLIT, $0-8 - MOVQ x+0(FP), AX - MOVQ 0(AX), DX - MOVQ 8(AX), CX - MOVQ 16(AX), BX - MOVQ 24(AX), SI - MOVQ 32(AX), DI - MOVQ 40(AX), R8 - ADDQ DX, DX - ADCQ CX, CX - ADCQ BX, BX - ADCQ SI, SI - ADCQ DI, DI - ADCQ R8, R8 - - // reduce element(DX,CX,BX,SI,DI,R8) using temp registers (R9,R10,R11,R12,R13,R14) - REDUCE(DX,CX,BX,SI,DI,R8,R9,R10,R11,R12,R13,R14) - - ADDQ DX, DX - ADCQ CX, CX - ADCQ BX, BX - ADCQ SI, SI - ADCQ DI, DI - ADCQ R8, R8 - - // reduce element(DX,CX,BX,SI,DI,R8) using temp registers (R15,R9,R10,R11,R12,R13) - REDUCE(DX,CX,BX,SI,DI,R8,R15,R9,R10,R11,R12,R13) - - ADDQ 0(AX), DX - ADCQ 8(AX), CX - ADCQ 16(AX), BX - ADCQ 24(AX), SI - ADCQ 32(AX), DI - ADCQ 40(AX), R8 - - // reduce element(DX,CX,BX,SI,DI,R8) using temp registers (R14,R15,R9,R10,R11,R12) - REDUCE(DX,CX,BX,SI,DI,R8,R14,R15,R9,R10,R11,R12) - - MOVQ DX, 0(AX) - MOVQ CX, 8(AX) - MOVQ BX, 16(AX) - MOVQ SI, 24(AX) - MOVQ DI, 32(AX) - MOVQ R8, 40(AX) - RET - -// MulBy13(x *Element) -TEXT ·MulBy13(SB), $40-8 - MOVQ x+0(FP), AX - MOVQ 0(AX), DX - MOVQ 8(AX), CX - MOVQ 16(AX), BX - MOVQ 24(AX), SI - MOVQ 32(AX), DI - MOVQ 40(AX), R8 - ADDQ DX, DX - ADCQ CX, CX - ADCQ BX, BX - ADCQ SI, SI - ADCQ DI, DI - ADCQ R8, R8 - - // reduce element(DX,CX,BX,SI,DI,R8) using temp registers (R9,R10,R11,R12,R13,R14) - REDUCE(DX,CX,BX,SI,DI,R8,R9,R10,R11,R12,R13,R14) - - ADDQ DX, DX - ADCQ CX, CX - ADCQ BX, BX - ADCQ SI, SI - ADCQ DI, DI - ADCQ R8, R8 - - // reduce element(DX,CX,BX,SI,DI,R8) using temp registers (R15,s0-8(SP),s1-16(SP),s2-24(SP),s3-32(SP),s4-40(SP)) - REDUCE(DX,CX,BX,SI,DI,R8,R15,s0-8(SP),s1-16(SP),s2-24(SP),s3-32(SP),s4-40(SP)) - - MOVQ DX, R15 - MOVQ CX, s0-8(SP) - MOVQ BX, s1-16(SP) - MOVQ SI, s2-24(SP) - MOVQ DI, s3-32(SP) - MOVQ R8, s4-40(SP) - ADDQ DX, DX - ADCQ CX, CX - ADCQ BX, BX - ADCQ SI, SI - ADCQ DI, DI - ADCQ R8, R8 - - // reduce element(DX,CX,BX,SI,DI,R8) using temp registers (R9,R10,R11,R12,R13,R14) - REDUCE(DX,CX,BX,SI,DI,R8,R9,R10,R11,R12,R13,R14) - - ADDQ R15, DX - ADCQ s0-8(SP), CX - ADCQ s1-16(SP), BX - ADCQ s2-24(SP), SI - ADCQ s3-32(SP), DI - ADCQ s4-40(SP), R8 - - // reduce element(DX,CX,BX,SI,DI,R8) using temp registers (R9,R10,R11,R12,R13,R14) - REDUCE(DX,CX,BX,SI,DI,R8,R9,R10,R11,R12,R13,R14) - - ADDQ 0(AX), DX - ADCQ 8(AX), CX - ADCQ 16(AX), BX - ADCQ 24(AX), SI - ADCQ 32(AX), DI - ADCQ 40(AX), R8 - - // reduce element(DX,CX,BX,SI,DI,R8) using temp registers (R9,R10,R11,R12,R13,R14) - REDUCE(DX,CX,BX,SI,DI,R8,R9,R10,R11,R12,R13,R14) - - MOVQ DX, 0(AX) - MOVQ CX, 8(AX) - MOVQ BX, 16(AX) - MOVQ SI, 24(AX) - MOVQ DI, 32(AX) - MOVQ R8, 40(AX) - RET - -// Butterfly(a, b *Element) sets a = a + b; b = a - b -TEXT ·Butterfly(SB), $48-16 - MOVQ a+0(FP), AX - MOVQ 0(AX), CX - MOVQ 8(AX), BX - MOVQ 16(AX), SI - MOVQ 24(AX), DI - MOVQ 32(AX), R8 - MOVQ 40(AX), R9 - MOVQ CX, R10 - MOVQ BX, R11 - MOVQ SI, R12 - MOVQ DI, R13 - MOVQ R8, R14 - MOVQ R9, R15 - XORQ AX, AX - MOVQ b+8(FP), DX - ADDQ 0(DX), CX - ADCQ 8(DX), BX - ADCQ 16(DX), SI - ADCQ 24(DX), DI - ADCQ 32(DX), R8 - ADCQ 40(DX), R9 - SUBQ 0(DX), R10 - SBBQ 8(DX), R11 - SBBQ 16(DX), R12 - SBBQ 24(DX), R13 - SBBQ 32(DX), R14 - SBBQ 40(DX), R15 - MOVQ CX, s0-8(SP) - MOVQ BX, s1-16(SP) - MOVQ SI, s2-24(SP) - 
MOVQ DI, s3-32(SP) - MOVQ R8, s4-40(SP) - MOVQ R9, s5-48(SP) - MOVQ $0x8508c00000000001, CX - MOVQ $0x170b5d4430000000, BX - MOVQ $0x1ef3622fba094800, SI - MOVQ $0x1a22d9f300f5138f, DI - MOVQ $0xc63b05c06ca1493b, R8 - MOVQ $0x01ae3a4617c510ea, R9 - CMOVQCC AX, CX - CMOVQCC AX, BX - CMOVQCC AX, SI - CMOVQCC AX, DI - CMOVQCC AX, R8 - CMOVQCC AX, R9 - ADDQ CX, R10 - ADCQ BX, R11 - ADCQ SI, R12 - ADCQ DI, R13 - ADCQ R8, R14 - ADCQ R9, R15 - MOVQ s0-8(SP), CX - MOVQ s1-16(SP), BX - MOVQ s2-24(SP), SI - MOVQ s3-32(SP), DI - MOVQ s4-40(SP), R8 - MOVQ s5-48(SP), R9 - MOVQ R10, 0(DX) - MOVQ R11, 8(DX) - MOVQ R12, 16(DX) - MOVQ R13, 24(DX) - MOVQ R14, 32(DX) - MOVQ R15, 40(DX) - - // reduce element(CX,BX,SI,DI,R8,R9) using temp registers (R10,R11,R12,R13,R14,R15) - REDUCE(CX,BX,SI,DI,R8,R9,R10,R11,R12,R13,R14,R15) - - MOVQ a+0(FP), AX - MOVQ CX, 0(AX) - MOVQ BX, 8(AX) - MOVQ SI, 16(AX) - MOVQ DI, 24(AX) - MOVQ R8, 32(AX) - MOVQ R9, 40(AX) - RET diff --git a/ecc/bls12-377/fp/element_ops_purego.go b/ecc/bls12-377/fp/element_ops_purego.go index a4c3796b9..072fb87c0 100644 --- a/ecc/bls12-377/fp/element_ops_purego.go +++ b/ecc/bls12-377/fp/element_ops_purego.go @@ -67,48 +67,8 @@ func reduce(z *Element) { // x and y must be less than q func (z *Element) Mul(x, y *Element) *Element { - // Implements CIOS multiplication -- section 2.3.2 of Tolga Acar's thesis - // https://www.microsoft.com/en-us/research/wp-content/uploads/1998/06/97Acar.pdf - // - // The algorithm: - // - // for i=0 to N-1 - // C := 0 - // for j=0 to N-1 - // (C,t[j]) := t[j] + x[j]*y[i] + C - // (t[N+1],t[N]) := t[N] + C - // - // C := 0 - // m := t[0]*q'[0] mod D - // (C,_) := t[0] + m*q[0] - // for j=1 to N-1 - // (C,t[j-1]) := t[j] + m*q[j] + C - // - // (C,t[N-1]) := t[N] + C - // t[N] := t[N+1] + C - // - // → N is the number of machine words needed to store the modulus q - // → D is the word size. For example, on a 64-bit architecture D is 2 64 - // → x[i], y[i], q[i] is the ith word of the numbers x,y,q - // → q'[0] is the lowest word of the number -q⁻¹ mod r. This quantity is pre-computed, as it does not depend on the inputs. - // → t is a temporary array of size N+2 - // → C, S are machine words. A pair (C,S) refers to (hi-bits, lo-bits) of a two-word number - // - // As described here https://hackmd.io/@gnark/modular_multiplication we can get rid of one carry chain and simplify: - // (also described in https://eprint.iacr.org/2022/1400.pdf annex) - // - // for i=0 to N-1 - // (A,t[0]) := t[0] + x[0]*y[i] - // m := t[0]*q'[0] mod W - // C,_ := t[0] + m*q[0] - // for j=1 to N-1 - // (A,t[j]) := t[j] + x[j]*y[i] + A - // (C,t[j-1]) := t[j] + m*q[j] + C - // - // t[N-1] = C + A - // - // This optimization saves 5N + 2 additions in the algorithm, and can be used whenever the highest bit - // of the modulus is zero (and not all of the remaining bits are set). + // Algorithm 2 of "Faster Montgomery Multiplication and Multi-Scalar-Multiplication for SNARKS" + // by Y. El Housni and G. 
Botrel https://doi.org/10.46586/tches.v2023.i3.504-521 var t0, t1, t2, t3, t4, t5 uint64 var u0, u1, u2, u3, u4, u5 uint64 diff --git a/ecc/bls12-377/fp/element_test.go b/ecc/bls12-377/fp/element_test.go index a1be8c091..09698d7dc 100644 --- a/ecc/bls12-377/fp/element_test.go +++ b/ecc/bls12-377/fp/element_test.go @@ -641,7 +641,6 @@ func TestElementBitLen(t *testing.T) { )) properties.TestingRun(t, gopter.ConsoleReporter(false)) - } func TestElementButterflies(t *testing.T) { @@ -723,11 +722,17 @@ func TestElementVecOps(t *testing.T) { // set m to max values element // it's not really q-1 (since we have montgomery representation) // but it's the "largest" legal value - qMinus1 := new(big.Int).Sub(Modulus(), big.NewInt(1)) - - var eQMinus1 Element - for i, v := range qMinus1.Bits() { - eQMinus1[i] = uint64(v) + eQMinus1 := qElement + if eQMinus1[0] != 0 { + eQMinus1[0]-- + } else { + eQMinus1[0] = ^uint64(0) + for i := 1; i < len(eQMinus1); i++ { + if eQMinus1[i] != 0 { + eQMinus1[i]-- + break + } + } } for i := 0; i < N; i++ { diff --git a/ecc/bls12-377/fr/element.go b/ecc/bls12-377/fr/element.go index 07be74489..c86a5ad75 100644 --- a/ecc/bls12-377/fr/element.go +++ b/ecc/bls12-377/fr/element.go @@ -477,32 +477,8 @@ func (z *Element) Select(c int, x0 *Element, x1 *Element) *Element { // and is used for testing purposes. func _mulGeneric(z, x, y *Element) { - // Implements CIOS multiplication -- section 2.3.2 of Tolga Acar's thesis - // https://www.microsoft.com/en-us/research/wp-content/uploads/1998/06/97Acar.pdf - // - // The algorithm: - // - // for i=0 to N-1 - // C := 0 - // for j=0 to N-1 - // (C,t[j]) := t[j] + x[j]*y[i] + C - // (t[N+1],t[N]) := t[N] + C - // - // C := 0 - // m := t[0]*q'[0] mod D - // (C,_) := t[0] + m*q[0] - // for j=1 to N-1 - // (C,t[j-1]) := t[j] + m*q[j] + C - // - // (C,t[N-1]) := t[N] + C - // t[N] := t[N+1] + C - // - // → N is the number of machine words needed to store the modulus q - // → D is the word size. For example, on a 64-bit architecture D is 2 64 - // → x[i], y[i], q[i] is the ith word of the numbers x,y,q - // → q'[0] is the lowest word of the number -q⁻¹ mod r. This quantity is pre-computed, as it does not depend on the inputs. - // → t is a temporary array of size N+2 - // → C, S are machine words. A pair (C,S) refers to (hi-bits, lo-bits) of a two-word number + // Algorithm 2 of "Faster Montgomery Multiplication and Multi-Scalar-Multiplication for SNARKS" + // by Y. El Housni and G. Botrel https://doi.org/10.46586/tches.v2023.i3.504-521 var t [5]uint64 var D uint64 diff --git a/ecc/bls12-377/fr/element_mul_amd64.s b/ecc/bls12-377/fr/element_mul_amd64.s deleted file mode 100644 index a8df29c64..000000000 --- a/ecc/bls12-377/fr/element_mul_amd64.s +++ /dev/null @@ -1,490 +0,0 @@ -// +build !purego - -// Copyright 2020 ConsenSys Software Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
- -#include "textflag.h" -#include "funcdata.h" - -// modulus q -DATA q<>+0(SB)/8, $0x0a11800000000001 -DATA q<>+8(SB)/8, $0x59aa76fed0000001 -DATA q<>+16(SB)/8, $0x60b44d1e5c37b001 -DATA q<>+24(SB)/8, $0x12ab655e9a2ca556 -GLOBL q<>(SB), (RODATA+NOPTR), $32 - -// qInv0 q'[0] -DATA qInv0<>(SB)/8, $0x0a117fffffffffff -GLOBL qInv0<>(SB), (RODATA+NOPTR), $8 -// Mu -DATA mu<>(SB)/8, $0x0000000db65247b1 -GLOBL mu<>(SB), (RODATA+NOPTR), $8 - -#define REDUCE(ra0, ra1, ra2, ra3, rb0, rb1, rb2, rb3) \ - MOVQ ra0, rb0; \ - SUBQ q<>(SB), ra0; \ - MOVQ ra1, rb1; \ - SBBQ q<>+8(SB), ra1; \ - MOVQ ra2, rb2; \ - SBBQ q<>+16(SB), ra2; \ - MOVQ ra3, rb3; \ - SBBQ q<>+24(SB), ra3; \ - CMOVQCS rb0, ra0; \ - CMOVQCS rb1, ra1; \ - CMOVQCS rb2, ra2; \ - CMOVQCS rb3, ra3; \ - -// mul(res, x, y *Element) -TEXT ·mul(SB), $24-24 - - // the algorithm is described in the Element.Mul declaration (.go) - // however, to benefit from the ADCX and ADOX carry chains - // we split the inner loops in 2: - // for i=0 to N-1 - // for j=0 to N-1 - // (A,t[j]) := t[j] + x[j]*y[i] + A - // m := t[0]*q'[0] mod W - // C,_ := t[0] + m*q[0] - // for j=1 to N-1 - // (C,t[j-1]) := t[j] + m*q[j] + C - // t[N-1] = C + A - - NO_LOCAL_POINTERS - CMPB ·supportAdx(SB), $1 - JNE noAdx_1 - MOVQ x+8(FP), SI - - // x[0] -> DI - // x[1] -> R8 - // x[2] -> R9 - // x[3] -> R10 - MOVQ 0(SI), DI - MOVQ 8(SI), R8 - MOVQ 16(SI), R9 - MOVQ 24(SI), R10 - MOVQ y+16(FP), R11 - - // A -> BP - // t[0] -> R14 - // t[1] -> R13 - // t[2] -> CX - // t[3] -> BX - // clear the flags - XORQ AX, AX - MOVQ 0(R11), DX - - // (A,t[0]) := x[0]*y[0] + A - MULXQ DI, R14, R13 - - // (A,t[1]) := x[1]*y[0] + A - MULXQ R8, AX, CX - ADOXQ AX, R13 - - // (A,t[2]) := x[2]*y[0] + A - MULXQ R9, AX, BX - ADOXQ AX, CX - - // (A,t[3]) := x[3]*y[0] + A - MULXQ R10, AX, BP - ADOXQ AX, BX - - // A += carries from ADCXQ and ADOXQ - MOVQ $0, AX - ADOXQ AX, BP - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ R14, DX - - // clear the flags - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ q<>+0(SB), AX, R12 - ADCXQ R14, AX - MOVQ R12, R14 - - // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R13, R14 - MULXQ q<>+8(SB), AX, R13 - ADOXQ AX, R14 - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ CX, R13 - MULXQ q<>+16(SB), AX, CX - ADOXQ AX, R13 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ BX, CX - MULXQ q<>+24(SB), AX, BX - ADOXQ AX, CX - - // t[3] = C + A - MOVQ $0, AX - ADCXQ AX, BX - ADOXQ BP, BX - - // clear the flags - XORQ AX, AX - MOVQ 8(R11), DX - - // (A,t[0]) := t[0] + x[0]*y[1] + A - MULXQ DI, AX, BP - ADOXQ AX, R14 - - // (A,t[1]) := t[1] + x[1]*y[1] + A - ADCXQ BP, R13 - MULXQ R8, AX, BP - ADOXQ AX, R13 - - // (A,t[2]) := t[2] + x[2]*y[1] + A - ADCXQ BP, CX - MULXQ R9, AX, BP - ADOXQ AX, CX - - // (A,t[3]) := t[3] + x[3]*y[1] + A - ADCXQ BP, BX - MULXQ R10, AX, BP - ADOXQ AX, BX - - // A += carries from ADCXQ and ADOXQ - MOVQ $0, AX - ADCXQ AX, BP - ADOXQ AX, BP - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ R14, DX - - // clear the flags - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ q<>+0(SB), AX, R12 - ADCXQ R14, AX - MOVQ R12, R14 - - // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R13, R14 - MULXQ q<>+8(SB), AX, R13 - ADOXQ AX, R14 - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ CX, R13 - MULXQ q<>+16(SB), AX, CX - ADOXQ AX, R13 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ BX, CX - MULXQ q<>+24(SB), AX, BX - ADOXQ AX, CX - - // t[3] = C + A - MOVQ $0, AX - ADCXQ AX, BX - ADOXQ BP, BX - - // clear the flags - XORQ AX, AX - MOVQ 16(R11), DX - - // (A,t[0]) := t[0] + x[0]*y[2] 
+ A - MULXQ DI, AX, BP - ADOXQ AX, R14 - - // (A,t[1]) := t[1] + x[1]*y[2] + A - ADCXQ BP, R13 - MULXQ R8, AX, BP - ADOXQ AX, R13 - - // (A,t[2]) := t[2] + x[2]*y[2] + A - ADCXQ BP, CX - MULXQ R9, AX, BP - ADOXQ AX, CX - - // (A,t[3]) := t[3] + x[3]*y[2] + A - ADCXQ BP, BX - MULXQ R10, AX, BP - ADOXQ AX, BX - - // A += carries from ADCXQ and ADOXQ - MOVQ $0, AX - ADCXQ AX, BP - ADOXQ AX, BP - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ R14, DX - - // clear the flags - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ q<>+0(SB), AX, R12 - ADCXQ R14, AX - MOVQ R12, R14 - - // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R13, R14 - MULXQ q<>+8(SB), AX, R13 - ADOXQ AX, R14 - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ CX, R13 - MULXQ q<>+16(SB), AX, CX - ADOXQ AX, R13 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ BX, CX - MULXQ q<>+24(SB), AX, BX - ADOXQ AX, CX - - // t[3] = C + A - MOVQ $0, AX - ADCXQ AX, BX - ADOXQ BP, BX - - // clear the flags - XORQ AX, AX - MOVQ 24(R11), DX - - // (A,t[0]) := t[0] + x[0]*y[3] + A - MULXQ DI, AX, BP - ADOXQ AX, R14 - - // (A,t[1]) := t[1] + x[1]*y[3] + A - ADCXQ BP, R13 - MULXQ R8, AX, BP - ADOXQ AX, R13 - - // (A,t[2]) := t[2] + x[2]*y[3] + A - ADCXQ BP, CX - MULXQ R9, AX, BP - ADOXQ AX, CX - - // (A,t[3]) := t[3] + x[3]*y[3] + A - ADCXQ BP, BX - MULXQ R10, AX, BP - ADOXQ AX, BX - - // A += carries from ADCXQ and ADOXQ - MOVQ $0, AX - ADCXQ AX, BP - ADOXQ AX, BP - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ R14, DX - - // clear the flags - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ q<>+0(SB), AX, R12 - ADCXQ R14, AX - MOVQ R12, R14 - - // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R13, R14 - MULXQ q<>+8(SB), AX, R13 - ADOXQ AX, R14 - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ CX, R13 - MULXQ q<>+16(SB), AX, CX - ADOXQ AX, R13 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ BX, CX - MULXQ q<>+24(SB), AX, BX - ADOXQ AX, CX - - // t[3] = C + A - MOVQ $0, AX - ADCXQ AX, BX - ADOXQ BP, BX - - // reduce element(R14,R13,CX,BX) using temp registers (SI,R12,R11,DI) - REDUCE(R14,R13,CX,BX,SI,R12,R11,DI) - - MOVQ res+0(FP), AX - MOVQ R14, 0(AX) - MOVQ R13, 8(AX) - MOVQ CX, 16(AX) - MOVQ BX, 24(AX) - RET - -noAdx_1: - MOVQ res+0(FP), AX - MOVQ AX, (SP) - MOVQ x+8(FP), AX - MOVQ AX, 8(SP) - MOVQ y+16(FP), AX - MOVQ AX, 16(SP) - CALL ·_mulGeneric(SB) - RET - -TEXT ·fromMont(SB), $8-8 - NO_LOCAL_POINTERS - - // the algorithm is described here - // https://hackmd.io/@gnark/modular_multiplication - // when y = 1 we have: - // for i=0 to N-1 - // t[i] = x[i] - // for i=0 to N-1 - // m := t[0]*q'[0] mod W - // C,_ := t[0] + m*q[0] - // for j=1 to N-1 - // (C,t[j-1]) := t[j] + m*q[j] + C - // t[N-1] = C - CMPB ·supportAdx(SB), $1 - JNE noAdx_2 - MOVQ res+0(FP), DX - MOVQ 0(DX), R14 - MOVQ 8(DX), R13 - MOVQ 16(DX), CX - MOVQ 24(DX), BX - XORQ DX, DX - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ R14, DX - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ q<>+0(SB), AX, BP - ADCXQ R14, AX - MOVQ BP, R14 - - // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R13, R14 - MULXQ q<>+8(SB), AX, R13 - ADOXQ AX, R14 - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ CX, R13 - MULXQ q<>+16(SB), AX, CX - ADOXQ AX, R13 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ BX, CX - MULXQ q<>+24(SB), AX, BX - ADOXQ AX, CX - MOVQ $0, AX - ADCXQ AX, BX - ADOXQ AX, BX - XORQ DX, DX - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ R14, DX - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ q<>+0(SB), AX, BP - ADCXQ R14, AX - MOVQ BP, R14 - - // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ 
R13, R14 - MULXQ q<>+8(SB), AX, R13 - ADOXQ AX, R14 - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ CX, R13 - MULXQ q<>+16(SB), AX, CX - ADOXQ AX, R13 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ BX, CX - MULXQ q<>+24(SB), AX, BX - ADOXQ AX, CX - MOVQ $0, AX - ADCXQ AX, BX - ADOXQ AX, BX - XORQ DX, DX - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ R14, DX - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ q<>+0(SB), AX, BP - ADCXQ R14, AX - MOVQ BP, R14 - - // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R13, R14 - MULXQ q<>+8(SB), AX, R13 - ADOXQ AX, R14 - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ CX, R13 - MULXQ q<>+16(SB), AX, CX - ADOXQ AX, R13 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ BX, CX - MULXQ q<>+24(SB), AX, BX - ADOXQ AX, CX - MOVQ $0, AX - ADCXQ AX, BX - ADOXQ AX, BX - XORQ DX, DX - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ R14, DX - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ q<>+0(SB), AX, BP - ADCXQ R14, AX - MOVQ BP, R14 - - // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R13, R14 - MULXQ q<>+8(SB), AX, R13 - ADOXQ AX, R14 - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ CX, R13 - MULXQ q<>+16(SB), AX, CX - ADOXQ AX, R13 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ BX, CX - MULXQ q<>+24(SB), AX, BX - ADOXQ AX, CX - MOVQ $0, AX - ADCXQ AX, BX - ADOXQ AX, BX - - // reduce element(R14,R13,CX,BX) using temp registers (SI,DI,R8,R9) - REDUCE(R14,R13,CX,BX,SI,DI,R8,R9) - - MOVQ res+0(FP), AX - MOVQ R14, 0(AX) - MOVQ R13, 8(AX) - MOVQ CX, 16(AX) - MOVQ BX, 24(AX) - RET - -noAdx_2: - MOVQ res+0(FP), AX - MOVQ AX, (SP) - CALL ·_fromMontGeneric(SB) - RET diff --git a/ecc/bls12-377/fr/element_ops_amd64.go b/ecc/bls12-377/fr/element_ops_amd64.go index 49da23450..d21dabdaa 100644 --- a/ecc/bls12-377/fr/element_ops_amd64.go +++ b/ecc/bls12-377/fr/element_ops_amd64.go @@ -103,48 +103,8 @@ func sumVec(res *Element, a *Element, n uint64) // x and y must be less than q func (z *Element) Mul(x, y *Element) *Element { - // Implements CIOS multiplication -- section 2.3.2 of Tolga Acar's thesis - // https://www.microsoft.com/en-us/research/wp-content/uploads/1998/06/97Acar.pdf - // - // The algorithm: - // - // for i=0 to N-1 - // C := 0 - // for j=0 to N-1 - // (C,t[j]) := t[j] + x[j]*y[i] + C - // (t[N+1],t[N]) := t[N] + C - // - // C := 0 - // m := t[0]*q'[0] mod D - // (C,_) := t[0] + m*q[0] - // for j=1 to N-1 - // (C,t[j-1]) := t[j] + m*q[j] + C - // - // (C,t[N-1]) := t[N] + C - // t[N] := t[N+1] + C - // - // → N is the number of machine words needed to store the modulus q - // → D is the word size. For example, on a 64-bit architecture D is 2 64 - // → x[i], y[i], q[i] is the ith word of the numbers x,y,q - // → q'[0] is the lowest word of the number -q⁻¹ mod r. This quantity is pre-computed, as it does not depend on the inputs. - // → t is a temporary array of size N+2 - // → C, S are machine words. A pair (C,S) refers to (hi-bits, lo-bits) of a two-word number - // - // As described here https://hackmd.io/@gnark/modular_multiplication we can get rid of one carry chain and simplify: - // (also described in https://eprint.iacr.org/2022/1400.pdf annex) - // - // for i=0 to N-1 - // (A,t[0]) := t[0] + x[0]*y[i] - // m := t[0]*q'[0] mod W - // C,_ := t[0] + m*q[0] - // for j=1 to N-1 - // (A,t[j]) := t[j] + x[j]*y[i] + A - // (C,t[j-1]) := t[j] + m*q[j] + C - // - // t[N-1] = C + A - // - // This optimization saves 5N + 2 additions in the algorithm, and can be used whenever the highest bit - // of the modulus is zero (and not all of the remaining bits are set). 
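// Illustrative sketch, not from this patch: for a single-word modulus (N = 1) the CIOS
// loop documented above collapses to one multiplication, one Montgomery reduction and a
// single conditional subtraction. q and qInv0 (= -q^-1 mod 2^64) stand in for the
// package's modulus constants, montMul1 is a hypothetical name, and math/bits is assumed
// to be imported.
func montMul1(x, y, q, qInv0 uint64) uint64 {
	hi, lo := bits.Mul64(x, y)         // t = x*y, split into hi:lo
	m := lo * qInv0                    // m := t[0]*q'[0] mod 2^64
	qHi, qLo := bits.Mul64(m, q)       // m*q
	_, carry := bits.Add64(lo, qLo, 0) // low word cancels to zero by construction of m
	r, over := bits.Add64(hi, qHi, carry)
	if over != 0 || r >= q {           // (x*y + m*q)/2^64 < 2q: one conditional subtraction
		r -= q
	}
	return r // = x*y*2^-64 mod q, i.e. the product in Montgomery form
}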
+ // Algorithm 2 of "Faster Montgomery Multiplication and Multi-Scalar-Multiplication for SNARKS" + // by Y. El Housni and G. Botrel https://doi.org/10.46586/tches.v2023.i3.504-521 mul(z, x, y) return z diff --git a/ecc/bls12-377/fr/element_ops_amd64.s b/ecc/bls12-377/fr/element_ops_amd64.s index 54a104040..3b876e622 100644 --- a/ecc/bls12-377/fr/element_ops_amd64.s +++ b/ecc/bls12-377/fr/element_ops_amd64.s @@ -1,21 +1,11 @@ // +build !purego -// Copyright 2020 ConsenSys Software Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. +#define q0 $0x0a11800000000001 +#define q1 $0x59aa76fed0000001 +#define q2 $0x60b44d1e5c37b001 +#define q3 $0x12ab655e9a2ca556 -#include "textflag.h" -#include "funcdata.h" +#include "../../../field/asm/element_4w_amd64.h" // modulus q DATA q<>+0(SB)/8, $0x0a11800000000001 @@ -31,809 +21,3 @@ GLOBL qInv0<>(SB), (RODATA+NOPTR), $8 DATA mu<>(SB)/8, $0x0000000db65247b1 GLOBL mu<>(SB), (RODATA+NOPTR), $8 -#define REDUCE(ra0, ra1, ra2, ra3, rb0, rb1, rb2, rb3) \ - MOVQ ra0, rb0; \ - SUBQ q<>(SB), ra0; \ - MOVQ ra1, rb1; \ - SBBQ q<>+8(SB), ra1; \ - MOVQ ra2, rb2; \ - SBBQ q<>+16(SB), ra2; \ - MOVQ ra3, rb3; \ - SBBQ q<>+24(SB), ra3; \ - CMOVQCS rb0, ra0; \ - CMOVQCS rb1, ra1; \ - CMOVQCS rb2, ra2; \ - CMOVQCS rb3, ra3; \ - -TEXT ·reduce(SB), NOSPLIT, $0-8 - MOVQ res+0(FP), AX - MOVQ 0(AX), DX - MOVQ 8(AX), CX - MOVQ 16(AX), BX - MOVQ 24(AX), SI - - // reduce element(DX,CX,BX,SI) using temp registers (DI,R8,R9,R10) - REDUCE(DX,CX,BX,SI,DI,R8,R9,R10) - - MOVQ DX, 0(AX) - MOVQ CX, 8(AX) - MOVQ BX, 16(AX) - MOVQ SI, 24(AX) - RET - -// MulBy3(x *Element) -TEXT ·MulBy3(SB), NOSPLIT, $0-8 - MOVQ x+0(FP), AX - MOVQ 0(AX), DX - MOVQ 8(AX), CX - MOVQ 16(AX), BX - MOVQ 24(AX), SI - ADDQ DX, DX - ADCQ CX, CX - ADCQ BX, BX - ADCQ SI, SI - - // reduce element(DX,CX,BX,SI) using temp registers (DI,R8,R9,R10) - REDUCE(DX,CX,BX,SI,DI,R8,R9,R10) - - ADDQ 0(AX), DX - ADCQ 8(AX), CX - ADCQ 16(AX), BX - ADCQ 24(AX), SI - - // reduce element(DX,CX,BX,SI) using temp registers (R11,R12,R13,R14) - REDUCE(DX,CX,BX,SI,R11,R12,R13,R14) - - MOVQ DX, 0(AX) - MOVQ CX, 8(AX) - MOVQ BX, 16(AX) - MOVQ SI, 24(AX) - RET - -// MulBy5(x *Element) -TEXT ·MulBy5(SB), NOSPLIT, $0-8 - MOVQ x+0(FP), AX - MOVQ 0(AX), DX - MOVQ 8(AX), CX - MOVQ 16(AX), BX - MOVQ 24(AX), SI - ADDQ DX, DX - ADCQ CX, CX - ADCQ BX, BX - ADCQ SI, SI - - // reduce element(DX,CX,BX,SI) using temp registers (DI,R8,R9,R10) - REDUCE(DX,CX,BX,SI,DI,R8,R9,R10) - - ADDQ DX, DX - ADCQ CX, CX - ADCQ BX, BX - ADCQ SI, SI - - // reduce element(DX,CX,BX,SI) using temp registers (R11,R12,R13,R14) - REDUCE(DX,CX,BX,SI,R11,R12,R13,R14) - - ADDQ 0(AX), DX - ADCQ 8(AX), CX - ADCQ 16(AX), BX - ADCQ 24(AX), SI - - // reduce element(DX,CX,BX,SI) using temp registers (R15,DI,R8,R9) - REDUCE(DX,CX,BX,SI,R15,DI,R8,R9) - - MOVQ DX, 0(AX) - MOVQ CX, 8(AX) - MOVQ BX, 16(AX) - MOVQ SI, 24(AX) - RET - -// MulBy13(x *Element) -TEXT ·MulBy13(SB), NOSPLIT, $0-8 - MOVQ x+0(FP), AX - MOVQ 0(AX), DX - MOVQ 8(AX), CX - MOVQ 16(AX), BX - MOVQ 24(AX), SI - 
ADDQ DX, DX - ADCQ CX, CX - ADCQ BX, BX - ADCQ SI, SI - - // reduce element(DX,CX,BX,SI) using temp registers (DI,R8,R9,R10) - REDUCE(DX,CX,BX,SI,DI,R8,R9,R10) - - ADDQ DX, DX - ADCQ CX, CX - ADCQ BX, BX - ADCQ SI, SI - - // reduce element(DX,CX,BX,SI) using temp registers (R11,R12,R13,R14) - REDUCE(DX,CX,BX,SI,R11,R12,R13,R14) - - MOVQ DX, R11 - MOVQ CX, R12 - MOVQ BX, R13 - MOVQ SI, R14 - ADDQ DX, DX - ADCQ CX, CX - ADCQ BX, BX - ADCQ SI, SI - - // reduce element(DX,CX,BX,SI) using temp registers (DI,R8,R9,R10) - REDUCE(DX,CX,BX,SI,DI,R8,R9,R10) - - ADDQ R11, DX - ADCQ R12, CX - ADCQ R13, BX - ADCQ R14, SI - - // reduce element(DX,CX,BX,SI) using temp registers (DI,R8,R9,R10) - REDUCE(DX,CX,BX,SI,DI,R8,R9,R10) - - ADDQ 0(AX), DX - ADCQ 8(AX), CX - ADCQ 16(AX), BX - ADCQ 24(AX), SI - - // reduce element(DX,CX,BX,SI) using temp registers (DI,R8,R9,R10) - REDUCE(DX,CX,BX,SI,DI,R8,R9,R10) - - MOVQ DX, 0(AX) - MOVQ CX, 8(AX) - MOVQ BX, 16(AX) - MOVQ SI, 24(AX) - RET - -// Butterfly(a, b *Element) sets a = a + b; b = a - b -TEXT ·Butterfly(SB), NOSPLIT, $0-16 - MOVQ a+0(FP), AX - MOVQ 0(AX), CX - MOVQ 8(AX), BX - MOVQ 16(AX), SI - MOVQ 24(AX), DI - MOVQ CX, R8 - MOVQ BX, R9 - MOVQ SI, R10 - MOVQ DI, R11 - XORQ AX, AX - MOVQ b+8(FP), DX - ADDQ 0(DX), CX - ADCQ 8(DX), BX - ADCQ 16(DX), SI - ADCQ 24(DX), DI - SUBQ 0(DX), R8 - SBBQ 8(DX), R9 - SBBQ 16(DX), R10 - SBBQ 24(DX), R11 - MOVQ $0x0a11800000000001, R12 - MOVQ $0x59aa76fed0000001, R13 - MOVQ $0x60b44d1e5c37b001, R14 - MOVQ $0x12ab655e9a2ca556, R15 - CMOVQCC AX, R12 - CMOVQCC AX, R13 - CMOVQCC AX, R14 - CMOVQCC AX, R15 - ADDQ R12, R8 - ADCQ R13, R9 - ADCQ R14, R10 - ADCQ R15, R11 - MOVQ R8, 0(DX) - MOVQ R9, 8(DX) - MOVQ R10, 16(DX) - MOVQ R11, 24(DX) - - // reduce element(CX,BX,SI,DI) using temp registers (R8,R9,R10,R11) - REDUCE(CX,BX,SI,DI,R8,R9,R10,R11) - - MOVQ a+0(FP), AX - MOVQ CX, 0(AX) - MOVQ BX, 8(AX) - MOVQ SI, 16(AX) - MOVQ DI, 24(AX) - RET - -// addVec(res, a, b *Element, n uint64) res[0...n] = a[0...n] + b[0...n] -TEXT ·addVec(SB), NOSPLIT, $0-32 - MOVQ res+0(FP), CX - MOVQ a+8(FP), AX - MOVQ b+16(FP), DX - MOVQ n+24(FP), BX - -loop_1: - TESTQ BX, BX - JEQ done_2 // n == 0, we are done - - // a[0] -> SI - // a[1] -> DI - // a[2] -> R8 - // a[3] -> R9 - MOVQ 0(AX), SI - MOVQ 8(AX), DI - MOVQ 16(AX), R8 - MOVQ 24(AX), R9 - ADDQ 0(DX), SI - ADCQ 8(DX), DI - ADCQ 16(DX), R8 - ADCQ 24(DX), R9 - - // reduce element(SI,DI,R8,R9) using temp registers (R10,R11,R12,R13) - REDUCE(SI,DI,R8,R9,R10,R11,R12,R13) - - MOVQ SI, 0(CX) - MOVQ DI, 8(CX) - MOVQ R8, 16(CX) - MOVQ R9, 24(CX) - - // increment pointers to visit next element - ADDQ $32, AX - ADDQ $32, DX - ADDQ $32, CX - DECQ BX // decrement n - JMP loop_1 - -done_2: - RET - -// subVec(res, a, b *Element, n uint64) res[0...n] = a[0...n] - b[0...n] -TEXT ·subVec(SB), NOSPLIT, $0-32 - MOVQ res+0(FP), CX - MOVQ a+8(FP), AX - MOVQ b+16(FP), DX - MOVQ n+24(FP), BX - XORQ SI, SI - -loop_3: - TESTQ BX, BX - JEQ done_4 // n == 0, we are done - - // a[0] -> DI - // a[1] -> R8 - // a[2] -> R9 - // a[3] -> R10 - MOVQ 0(AX), DI - MOVQ 8(AX), R8 - MOVQ 16(AX), R9 - MOVQ 24(AX), R10 - SUBQ 0(DX), DI - SBBQ 8(DX), R8 - SBBQ 16(DX), R9 - SBBQ 24(DX), R10 - - // reduce (a-b) mod q - // q[0] -> R11 - // q[1] -> R12 - // q[2] -> R13 - // q[3] -> R14 - MOVQ $0x0a11800000000001, R11 - MOVQ $0x59aa76fed0000001, R12 - MOVQ $0x60b44d1e5c37b001, R13 - MOVQ $0x12ab655e9a2ca556, R14 - CMOVQCC SI, R11 - CMOVQCC SI, R12 - CMOVQCC SI, R13 - CMOVQCC SI, R14 - - // add registers (q or 0) to a, and set to result - ADDQ 
R11, DI - ADCQ R12, R8 - ADCQ R13, R9 - ADCQ R14, R10 - MOVQ DI, 0(CX) - MOVQ R8, 8(CX) - MOVQ R9, 16(CX) - MOVQ R10, 24(CX) - - // increment pointers to visit next element - ADDQ $32, AX - ADDQ $32, DX - ADDQ $32, CX - DECQ BX // decrement n - JMP loop_3 - -done_4: - RET - -// scalarMulVec(res, a, b *Element, n uint64) res[0...n] = a[0...n] * b -TEXT ·scalarMulVec(SB), $56-32 - CMPB ·supportAdx(SB), $1 - JNE noAdx_5 - MOVQ a+8(FP), R11 - MOVQ b+16(FP), R10 - MOVQ n+24(FP), R12 - - // scalar[0] -> SI - // scalar[1] -> DI - // scalar[2] -> R8 - // scalar[3] -> R9 - MOVQ 0(R10), SI - MOVQ 8(R10), DI - MOVQ 16(R10), R8 - MOVQ 24(R10), R9 - MOVQ res+0(FP), R10 - -loop_6: - TESTQ R12, R12 - JEQ done_7 // n == 0, we are done - - // TODO @gbotrel this is generated from the same macro as the unit mul, we should refactor this in a single asm function - // A -> BP - // t[0] -> R14 - // t[1] -> R15 - // t[2] -> CX - // t[3] -> BX - // clear the flags - XORQ AX, AX - MOVQ 0(R11), DX - - // (A,t[0]) := x[0]*y[0] + A - MULXQ SI, R14, R15 - - // (A,t[1]) := x[1]*y[0] + A - MULXQ DI, AX, CX - ADOXQ AX, R15 - - // (A,t[2]) := x[2]*y[0] + A - MULXQ R8, AX, BX - ADOXQ AX, CX - - // (A,t[3]) := x[3]*y[0] + A - MULXQ R9, AX, BP - ADOXQ AX, BX - - // A += carries from ADCXQ and ADOXQ - MOVQ $0, AX - ADOXQ AX, BP - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ R14, DX - - // clear the flags - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ q<>+0(SB), AX, R13 - ADCXQ R14, AX - MOVQ R13, R14 - - // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R15, R14 - MULXQ q<>+8(SB), AX, R15 - ADOXQ AX, R14 - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ CX, R15 - MULXQ q<>+16(SB), AX, CX - ADOXQ AX, R15 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ BX, CX - MULXQ q<>+24(SB), AX, BX - ADOXQ AX, CX - - // t[3] = C + A - MOVQ $0, AX - ADCXQ AX, BX - ADOXQ BP, BX - - // clear the flags - XORQ AX, AX - MOVQ 8(R11), DX - - // (A,t[0]) := t[0] + x[0]*y[1] + A - MULXQ SI, AX, BP - ADOXQ AX, R14 - - // (A,t[1]) := t[1] + x[1]*y[1] + A - ADCXQ BP, R15 - MULXQ DI, AX, BP - ADOXQ AX, R15 - - // (A,t[2]) := t[2] + x[2]*y[1] + A - ADCXQ BP, CX - MULXQ R8, AX, BP - ADOXQ AX, CX - - // (A,t[3]) := t[3] + x[3]*y[1] + A - ADCXQ BP, BX - MULXQ R9, AX, BP - ADOXQ AX, BX - - // A += carries from ADCXQ and ADOXQ - MOVQ $0, AX - ADCXQ AX, BP - ADOXQ AX, BP - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ R14, DX - - // clear the flags - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ q<>+0(SB), AX, R13 - ADCXQ R14, AX - MOVQ R13, R14 - - // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R15, R14 - MULXQ q<>+8(SB), AX, R15 - ADOXQ AX, R14 - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ CX, R15 - MULXQ q<>+16(SB), AX, CX - ADOXQ AX, R15 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ BX, CX - MULXQ q<>+24(SB), AX, BX - ADOXQ AX, CX - - // t[3] = C + A - MOVQ $0, AX - ADCXQ AX, BX - ADOXQ BP, BX - - // clear the flags - XORQ AX, AX - MOVQ 16(R11), DX - - // (A,t[0]) := t[0] + x[0]*y[2] + A - MULXQ SI, AX, BP - ADOXQ AX, R14 - - // (A,t[1]) := t[1] + x[1]*y[2] + A - ADCXQ BP, R15 - MULXQ DI, AX, BP - ADOXQ AX, R15 - - // (A,t[2]) := t[2] + x[2]*y[2] + A - ADCXQ BP, CX - MULXQ R8, AX, BP - ADOXQ AX, CX - - // (A,t[3]) := t[3] + x[3]*y[2] + A - ADCXQ BP, BX - MULXQ R9, AX, BP - ADOXQ AX, BX - - // A += carries from ADCXQ and ADOXQ - MOVQ $0, AX - ADCXQ AX, BP - ADOXQ AX, BP - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ R14, DX - - // clear the flags - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ q<>+0(SB), AX, R13 - ADCXQ R14, AX - MOVQ R13, 
R14 - - // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R15, R14 - MULXQ q<>+8(SB), AX, R15 - ADOXQ AX, R14 - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ CX, R15 - MULXQ q<>+16(SB), AX, CX - ADOXQ AX, R15 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ BX, CX - MULXQ q<>+24(SB), AX, BX - ADOXQ AX, CX - - // t[3] = C + A - MOVQ $0, AX - ADCXQ AX, BX - ADOXQ BP, BX - - // clear the flags - XORQ AX, AX - MOVQ 24(R11), DX - - // (A,t[0]) := t[0] + x[0]*y[3] + A - MULXQ SI, AX, BP - ADOXQ AX, R14 - - // (A,t[1]) := t[1] + x[1]*y[3] + A - ADCXQ BP, R15 - MULXQ DI, AX, BP - ADOXQ AX, R15 - - // (A,t[2]) := t[2] + x[2]*y[3] + A - ADCXQ BP, CX - MULXQ R8, AX, BP - ADOXQ AX, CX - - // (A,t[3]) := t[3] + x[3]*y[3] + A - ADCXQ BP, BX - MULXQ R9, AX, BP - ADOXQ AX, BX - - // A += carries from ADCXQ and ADOXQ - MOVQ $0, AX - ADCXQ AX, BP - ADOXQ AX, BP - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ R14, DX - - // clear the flags - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ q<>+0(SB), AX, R13 - ADCXQ R14, AX - MOVQ R13, R14 - - // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R15, R14 - MULXQ q<>+8(SB), AX, R15 - ADOXQ AX, R14 - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ CX, R15 - MULXQ q<>+16(SB), AX, CX - ADOXQ AX, R15 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ BX, CX - MULXQ q<>+24(SB), AX, BX - ADOXQ AX, CX - - // t[3] = C + A - MOVQ $0, AX - ADCXQ AX, BX - ADOXQ BP, BX - - // reduce t mod q - // reduce element(R14,R15,CX,BX) using temp registers (R13,AX,DX,s0-8(SP)) - REDUCE(R14,R15,CX,BX,R13,AX,DX,s0-8(SP)) - - MOVQ R14, 0(R10) - MOVQ R15, 8(R10) - MOVQ CX, 16(R10) - MOVQ BX, 24(R10) - - // increment pointers to visit next element - ADDQ $32, R11 - ADDQ $32, R10 - DECQ R12 // decrement n - JMP loop_6 - -done_7: - RET - -noAdx_5: - MOVQ n+24(FP), DX - MOVQ res+0(FP), AX - MOVQ AX, (SP) - MOVQ DX, 8(SP) - MOVQ DX, 16(SP) - MOVQ a+8(FP), AX - MOVQ AX, 24(SP) - MOVQ DX, 32(SP) - MOVQ DX, 40(SP) - MOVQ b+16(FP), AX - MOVQ AX, 48(SP) - CALL ·scalarMulVecGeneric(SB) - RET - -// sumVec(res, a *Element, n uint64) res = sum(a[0...n]) -TEXT ·sumVec(SB), NOSPLIT, $0-24 - - // Derived from https://github.com/a16z/vectorized-fields - // The idea is to use Z registers to accumulate the sum of elements, 8 by 8 - // first, we handle the case where n % 8 != 0 - // then, we loop over the elements 8 by 8 and accumulate the sum in the Z registers - // finally, we reduce the sum and store it in res - // - // when we move an element of a into a Z register, we use VPMOVZXDQ - // let's note w0...w3 the 4 64bits words of ai: w0 = ai[0], w1 = ai[1], w2 = ai[2], w3 = ai[3] - // VPMOVZXDQ(ai, Z0) will result in - // Z0= [hi(w3), lo(w3), hi(w2), lo(w2), hi(w1), lo(w1), hi(w0), lo(w0)] - // with hi(wi) the high 32 bits of wi and lo(wi) the low 32 bits of wi - // we can safely add 2^32+1 times Z registers constructed this way without overflow - // since each of this lo/hi bits are moved into a "64bits" slot - // N = 2^64-1 / 2^32-1 = 2^32+1 - // - // we then propagate the carry using ADOXQ and ADCXQ - // r0 = w0l + lo(woh) - // r1 = carry + hi(woh) + w1l + lo(w1h) - // r2 = carry + hi(w1h) + w2l + lo(w2h) - // r3 = carry + hi(w2h) + w3l + lo(w3h) - // r4 = carry + hi(w3h) - // we then reduce the sum using a single-word Barrett reduction - // we pick mu = 2^288 / q; which correspond to 4.5 words max. - // meaning we must guarantee that r4 fits in 32bits. 
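
The sumVec comment above (it continues just below) leans on the bound N = (2^64-1)/(2^32-1) = 2^32+1: each Z-register lane holds a zero-extended 32-bit half, so at most that many summands can be accumulated before a 64-bit lane could overflow. A tiny, self-contained Go check of that arithmetic (names are illustrative only, not part of the patch):

package main

import "fmt"

// Each Z-register lane holds a zero-extended 32-bit half, so its value is at
// most 2^32-1 per summand; a 64-bit lane can therefore absorb
// (2^64-1)/(2^32-1) = 2^32+1 such additions before it can overflow.
const maxSafeAdds = (1<<64 - 1) / (1<<32 - 1)

func main() {
	fmt.Println(uint64(maxSafeAdds)) // 4294967297 == 2^32 + 1
}

As the comment notes just below, the implementation lowers this bound further to 2^32-1 to leave room for the two carries that land in r4.
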
- // To do so, we reduce N to 2^32-1 (since r4 receives 2 carries max) - - MOVQ a+8(FP), R14 - MOVQ n+16(FP), R15 - - // initialize accumulators Z0, Z1, Z2, Z3, Z4, Z5, Z6, Z7 - VXORPS Z0, Z0, Z0 - VMOVDQA64 Z0, Z1 - VMOVDQA64 Z0, Z2 - VMOVDQA64 Z0, Z3 - VMOVDQA64 Z0, Z4 - VMOVDQA64 Z0, Z5 - VMOVDQA64 Z0, Z6 - VMOVDQA64 Z0, Z7 - - // n % 8 -> CX - // n / 8 -> R15 - MOVQ R15, CX - ANDQ $7, CX - SHRQ $3, R15 - -loop_single_10: - TESTQ CX, CX - JEQ loop8by8_8 // n % 8 == 0, we are going to loop over 8 by 8 - VPMOVZXDQ 0(R14), Z8 - VPADDQ Z8, Z0, Z0 - ADDQ $32, R14 - DECQ CX // decrement nMod8 - JMP loop_single_10 - -loop8by8_8: - TESTQ R15, R15 - JEQ accumulate_11 // n == 0, we are going to accumulate - VPMOVZXDQ 0*32(R14), Z8 - VPMOVZXDQ 1*32(R14), Z9 - VPMOVZXDQ 2*32(R14), Z10 - VPMOVZXDQ 3*32(R14), Z11 - VPMOVZXDQ 4*32(R14), Z12 - VPMOVZXDQ 5*32(R14), Z13 - VPMOVZXDQ 6*32(R14), Z14 - VPMOVZXDQ 7*32(R14), Z15 - PREFETCHT0 256(R14) - VPADDQ Z8, Z0, Z0 - VPADDQ Z9, Z1, Z1 - VPADDQ Z10, Z2, Z2 - VPADDQ Z11, Z3, Z3 - VPADDQ Z12, Z4, Z4 - VPADDQ Z13, Z5, Z5 - VPADDQ Z14, Z6, Z6 - VPADDQ Z15, Z7, Z7 - - // increment pointers to visit next 8 elements - ADDQ $256, R14 - DECQ R15 // decrement n - JMP loop8by8_8 - -accumulate_11: - // accumulate the 8 Z registers into Z0 - VPADDQ Z7, Z6, Z6 - VPADDQ Z6, Z5, Z5 - VPADDQ Z5, Z4, Z4 - VPADDQ Z4, Z3, Z3 - VPADDQ Z3, Z2, Z2 - VPADDQ Z2, Z1, Z1 - VPADDQ Z1, Z0, Z0 - - // carry propagation - // lo(w0) -> BX - // hi(w0) -> SI - // lo(w1) -> DI - // hi(w1) -> R8 - // lo(w2) -> R9 - // hi(w2) -> R10 - // lo(w3) -> R11 - // hi(w3) -> R12 - VMOVQ X0, BX - VALIGNQ $1, Z0, Z0, Z0 - VMOVQ X0, SI - VALIGNQ $1, Z0, Z0, Z0 - VMOVQ X0, DI - VALIGNQ $1, Z0, Z0, Z0 - VMOVQ X0, R8 - VALIGNQ $1, Z0, Z0, Z0 - VMOVQ X0, R9 - VALIGNQ $1, Z0, Z0, Z0 - VMOVQ X0, R10 - VALIGNQ $1, Z0, Z0, Z0 - VMOVQ X0, R11 - VALIGNQ $1, Z0, Z0, Z0 - VMOVQ X0, R12 - - // lo(hi(wo)) -> R13 - // lo(hi(w1)) -> CX - // lo(hi(w2)) -> R15 - // lo(hi(w3)) -> R14 -#define SPLIT_LO_HI(lo, hi) \ - MOVQ hi, lo; \ - ANDQ $0xffffffff, lo; \ - SHLQ $32, lo; \ - SHRQ $32, hi; \ - - SPLIT_LO_HI(R13, SI) - SPLIT_LO_HI(CX, R8) - SPLIT_LO_HI(R15, R10) - SPLIT_LO_HI(R14, R12) - - // r0 = w0l + lo(woh) - // r1 = carry + hi(woh) + w1l + lo(w1h) - // r2 = carry + hi(w1h) + w2l + lo(w2h) - // r3 = carry + hi(w2h) + w3l + lo(w3h) - // r4 = carry + hi(w3h) - - XORQ AX, AX // clear the flags - ADOXQ R13, BX - ADOXQ CX, DI - ADCXQ SI, DI - ADOXQ R15, R9 - ADCXQ R8, R9 - ADOXQ R14, R11 - ADCXQ R10, R11 - ADOXQ AX, R12 - ADCXQ AX, R12 - - // r[0] -> BX - // r[1] -> DI - // r[2] -> R9 - // r[3] -> R11 - // r[4] -> R12 - // reduce using single-word Barrett - // mu=2^288 / q -> SI - MOVQ mu<>(SB), SI - MOVQ R11, AX - SHRQ $32, R12, AX - MULQ SI // high bits of res stored in DX - MULXQ q<>+0(SB), AX, SI - SUBQ AX, BX - SBBQ SI, DI - MULXQ q<>+16(SB), AX, SI - SBBQ AX, R9 - SBBQ SI, R11 - SBBQ $0, R12 - MULXQ q<>+8(SB), AX, SI - SUBQ AX, DI - SBBQ SI, R9 - MULXQ q<>+24(SB), AX, SI - SBBQ AX, R11 - SBBQ SI, R12 - MOVQ BX, R8 - MOVQ DI, R10 - MOVQ R9, R13 - MOVQ R11, CX - SUBQ q<>+0(SB), BX - SBBQ q<>+8(SB), DI - SBBQ q<>+16(SB), R9 - SBBQ q<>+24(SB), R11 - SBBQ $0, R12 - JCS modReduced_12 - MOVQ BX, R8 - MOVQ DI, R10 - MOVQ R9, R13 - MOVQ R11, CX - SUBQ q<>+0(SB), BX - SBBQ q<>+8(SB), DI - SBBQ q<>+16(SB), R9 - SBBQ q<>+24(SB), R11 - SBBQ $0, R12 - JCS modReduced_12 - MOVQ BX, R8 - MOVQ DI, R10 - MOVQ R9, R13 - MOVQ R11, CX - -modReduced_12: - MOVQ res+0(FP), SI - MOVQ R8, 0(SI) - MOVQ R10, 8(SI) - MOVQ R13, 16(SI) - MOVQ CX, 
24(SI) - -done_9: - RET diff --git a/ecc/bls12-377/fr/element_ops_purego.go b/ecc/bls12-377/fr/element_ops_purego.go index 446bec2e2..273925197 100644 --- a/ecc/bls12-377/fr/element_ops_purego.go +++ b/ecc/bls12-377/fr/element_ops_purego.go @@ -89,48 +89,8 @@ func (vector *Vector) Sum() (res Element) { // x and y must be less than q func (z *Element) Mul(x, y *Element) *Element { - // Implements CIOS multiplication -- section 2.3.2 of Tolga Acar's thesis - // https://www.microsoft.com/en-us/research/wp-content/uploads/1998/06/97Acar.pdf - // - // The algorithm: - // - // for i=0 to N-1 - // C := 0 - // for j=0 to N-1 - // (C,t[j]) := t[j] + x[j]*y[i] + C - // (t[N+1],t[N]) := t[N] + C - // - // C := 0 - // m := t[0]*q'[0] mod D - // (C,_) := t[0] + m*q[0] - // for j=1 to N-1 - // (C,t[j-1]) := t[j] + m*q[j] + C - // - // (C,t[N-1]) := t[N] + C - // t[N] := t[N+1] + C - // - // → N is the number of machine words needed to store the modulus q - // → D is the word size. For example, on a 64-bit architecture D is 2 64 - // → x[i], y[i], q[i] is the ith word of the numbers x,y,q - // → q'[0] is the lowest word of the number -q⁻¹ mod r. This quantity is pre-computed, as it does not depend on the inputs. - // → t is a temporary array of size N+2 - // → C, S are machine words. A pair (C,S) refers to (hi-bits, lo-bits) of a two-word number - // - // As described here https://hackmd.io/@gnark/modular_multiplication we can get rid of one carry chain and simplify: - // (also described in https://eprint.iacr.org/2022/1400.pdf annex) - // - // for i=0 to N-1 - // (A,t[0]) := t[0] + x[0]*y[i] - // m := t[0]*q'[0] mod W - // C,_ := t[0] + m*q[0] - // for j=1 to N-1 - // (A,t[j]) := t[j] + x[j]*y[i] + A - // (C,t[j-1]) := t[j] + m*q[j] + C - // - // t[N-1] = C + A - // - // This optimization saves 5N + 2 additions in the algorithm, and can be used whenever the highest bit - // of the modulus is zero (and not all of the remaining bits are set). + // Algorithm 2 of "Faster Montgomery Multiplication and Multi-Scalar-Multiplication for SNARKS" + // by Y. El Housni and G. Botrel https://doi.org/10.46586/tches.v2023.i3.504-521 var t0, t1, t2, t3 uint64 var u0, u1, u2, u3 uint64 diff --git a/ecc/bls12-377/fr/element_test.go b/ecc/bls12-377/fr/element_test.go index 4e3d89236..8f17f77d9 100644 --- a/ecc/bls12-377/fr/element_test.go +++ b/ecc/bls12-377/fr/element_test.go @@ -637,7 +637,6 @@ func TestElementBitLen(t *testing.T) { )) properties.TestingRun(t, gopter.ConsoleReporter(false)) - } func TestElementButterflies(t *testing.T) { @@ -719,11 +718,17 @@ func TestElementVecOps(t *testing.T) { // set m to max values element // it's not really q-1 (since we have montgomery representation) // but it's the "largest" legal value - qMinus1 := new(big.Int).Sub(Modulus(), big.NewInt(1)) - - var eQMinus1 Element - for i, v := range qMinus1.Bits() { - eQMinus1[i] = uint64(v) + eQMinus1 := qElement + if eQMinus1[0] != 0 { + eQMinus1[0]-- + } else { + eQMinus1[0] = ^uint64(0) + for i := 1; i < len(eQMinus1); i++ { + if eQMinus1[i] != 0 { + eQMinus1[i]-- + break + } + } } for i := 0; i < N; i++ { diff --git a/ecc/bls12-381/fp/element.go b/ecc/bls12-381/fp/element.go index f5c2df0c2..f0bcfe51b 100644 --- a/ecc/bls12-381/fp/element.go +++ b/ecc/bls12-381/fp/element.go @@ -521,32 +521,8 @@ func (z *Element) Select(c int, x0 *Element, x1 *Element) *Element { // and is used for testing purposes. 
func _mulGeneric(z, x, y *Element) { - // Implements CIOS multiplication -- section 2.3.2 of Tolga Acar's thesis - // https://www.microsoft.com/en-us/research/wp-content/uploads/1998/06/97Acar.pdf - // - // The algorithm: - // - // for i=0 to N-1 - // C := 0 - // for j=0 to N-1 - // (C,t[j]) := t[j] + x[j]*y[i] + C - // (t[N+1],t[N]) := t[N] + C - // - // C := 0 - // m := t[0]*q'[0] mod D - // (C,_) := t[0] + m*q[0] - // for j=1 to N-1 - // (C,t[j-1]) := t[j] + m*q[j] + C - // - // (C,t[N-1]) := t[N] + C - // t[N] := t[N+1] + C - // - // → N is the number of machine words needed to store the modulus q - // → D is the word size. For example, on a 64-bit architecture D is 2 64 - // → x[i], y[i], q[i] is the ith word of the numbers x,y,q - // → q'[0] is the lowest word of the number -q⁻¹ mod r. This quantity is pre-computed, as it does not depend on the inputs. - // → t is a temporary array of size N+2 - // → C, S are machine words. A pair (C,S) refers to (hi-bits, lo-bits) of a two-word number + // Algorithm 2 of "Faster Montgomery Multiplication and Multi-Scalar-Multiplication for SNARKS" + // by Y. El Housni and G. Botrel https://doi.org/10.46586/tches.v2023.i3.504-521 var t [7]uint64 var D uint64 diff --git a/ecc/bls12-381/fp/element_mul_amd64.s b/ecc/bls12-381/fp/element_mul_amd64.s deleted file mode 100644 index e95c98403..000000000 --- a/ecc/bls12-381/fp/element_mul_amd64.s +++ /dev/null @@ -1,857 +0,0 @@ -// +build !purego - -// Copyright 2020 ConsenSys Software Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
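
The assembly routines in this patch (scalarMulVec above, and the deleted mul and fromMont below) all begin with CMPB ·supportAdx(SB), $1 and branch to a noAdx label when the CPU lacks the ADX/BMI2 extensions behind MULXQ/ADCXQ/ADOXQ, falling back to the generic Go code. A sketch of how such a flag is typically populated, assuming golang.org/x/sys/cpu; the exact wiring in gnark-crypto may differ:

package fp // sketch; the real flag lives next to the generated assembly

import "golang.org/x/sys/cpu"

// supportAdx gates the ADX/BMI2 fast path (MULXQ/ADCXQ/ADOXQ); when it is false
// the assembly jumps to its noAdx label and calls the generic Go fallback.
var supportAdx = cpu.X86.HasADX && cpu.X86.HasBMI2
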
- -#include "textflag.h" -#include "funcdata.h" - -// modulus q -DATA q<>+0(SB)/8, $0xb9feffffffffaaab -DATA q<>+8(SB)/8, $0x1eabfffeb153ffff -DATA q<>+16(SB)/8, $0x6730d2a0f6b0f624 -DATA q<>+24(SB)/8, $0x64774b84f38512bf -DATA q<>+32(SB)/8, $0x4b1ba7b6434bacd7 -DATA q<>+40(SB)/8, $0x1a0111ea397fe69a -GLOBL q<>(SB), (RODATA+NOPTR), $48 - -// qInv0 q'[0] -DATA qInv0<>(SB)/8, $0x89f3fffcfffcfffd -GLOBL qInv0<>(SB), (RODATA+NOPTR), $8 - -#define REDUCE(ra0, ra1, ra2, ra3, ra4, ra5, rb0, rb1, rb2, rb3, rb4, rb5) \ - MOVQ ra0, rb0; \ - SUBQ q<>(SB), ra0; \ - MOVQ ra1, rb1; \ - SBBQ q<>+8(SB), ra1; \ - MOVQ ra2, rb2; \ - SBBQ q<>+16(SB), ra2; \ - MOVQ ra3, rb3; \ - SBBQ q<>+24(SB), ra3; \ - MOVQ ra4, rb4; \ - SBBQ q<>+32(SB), ra4; \ - MOVQ ra5, rb5; \ - SBBQ q<>+40(SB), ra5; \ - CMOVQCS rb0, ra0; \ - CMOVQCS rb1, ra1; \ - CMOVQCS rb2, ra2; \ - CMOVQCS rb3, ra3; \ - CMOVQCS rb4, ra4; \ - CMOVQCS rb5, ra5; \ - -// mul(res, x, y *Element) -TEXT ·mul(SB), $24-24 - - // the algorithm is described in the Element.Mul declaration (.go) - // however, to benefit from the ADCX and ADOX carry chains - // we split the inner loops in 2: - // for i=0 to N-1 - // for j=0 to N-1 - // (A,t[j]) := t[j] + x[j]*y[i] + A - // m := t[0]*q'[0] mod W - // C,_ := t[0] + m*q[0] - // for j=1 to N-1 - // (C,t[j-1]) := t[j] + m*q[j] + C - // t[N-1] = C + A - - NO_LOCAL_POINTERS - CMPB ·supportAdx(SB), $1 - JNE noAdx_1 - MOVQ x+8(FP), R8 - - // x[0] -> R10 - // x[1] -> R11 - // x[2] -> R12 - MOVQ 0(R8), R10 - MOVQ 8(R8), R11 - MOVQ 16(R8), R12 - MOVQ y+16(FP), R13 - - // A -> BP - // t[0] -> R14 - // t[1] -> R15 - // t[2] -> CX - // t[3] -> BX - // t[4] -> SI - // t[5] -> DI - // clear the flags - XORQ AX, AX - MOVQ 0(R13), DX - - // (A,t[0]) := x[0]*y[0] + A - MULXQ R10, R14, R15 - - // (A,t[1]) := x[1]*y[0] + A - MULXQ R11, AX, CX - ADOXQ AX, R15 - - // (A,t[2]) := x[2]*y[0] + A - MULXQ R12, AX, BX - ADOXQ AX, CX - - // (A,t[3]) := x[3]*y[0] + A - MULXQ 24(R8), AX, SI - ADOXQ AX, BX - - // (A,t[4]) := x[4]*y[0] + A - MULXQ 32(R8), AX, DI - ADOXQ AX, SI - - // (A,t[5]) := x[5]*y[0] + A - MULXQ 40(R8), AX, BP - ADOXQ AX, DI - - // A += carries from ADCXQ and ADOXQ - MOVQ $0, AX - ADOXQ AX, BP - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ R14, DX - - // clear the flags - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ q<>+0(SB), AX, R9 - ADCXQ R14, AX - MOVQ R9, R14 - - // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R15, R14 - MULXQ q<>+8(SB), AX, R15 - ADOXQ AX, R14 - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ CX, R15 - MULXQ q<>+16(SB), AX, CX - ADOXQ AX, R15 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ BX, CX - MULXQ q<>+24(SB), AX, BX - ADOXQ AX, CX - - // (C,t[3]) := t[4] + m*q[4] + C - ADCXQ SI, BX - MULXQ q<>+32(SB), AX, SI - ADOXQ AX, BX - - // (C,t[4]) := t[5] + m*q[5] + C - ADCXQ DI, SI - MULXQ q<>+40(SB), AX, DI - ADOXQ AX, SI - - // t[5] = C + A - MOVQ $0, AX - ADCXQ AX, DI - ADOXQ BP, DI - - // clear the flags - XORQ AX, AX - MOVQ 8(R13), DX - - // (A,t[0]) := t[0] + x[0]*y[1] + A - MULXQ R10, AX, BP - ADOXQ AX, R14 - - // (A,t[1]) := t[1] + x[1]*y[1] + A - ADCXQ BP, R15 - MULXQ R11, AX, BP - ADOXQ AX, R15 - - // (A,t[2]) := t[2] + x[2]*y[1] + A - ADCXQ BP, CX - MULXQ R12, AX, BP - ADOXQ AX, CX - - // (A,t[3]) := t[3] + x[3]*y[1] + A - ADCXQ BP, BX - MULXQ 24(R8), AX, BP - ADOXQ AX, BX - - // (A,t[4]) := t[4] + x[4]*y[1] + A - ADCXQ BP, SI - MULXQ 32(R8), AX, BP - ADOXQ AX, SI - - // (A,t[5]) := t[5] + x[5]*y[1] + A - ADCXQ BP, DI - MULXQ 40(R8), AX, BP - ADOXQ AX, DI - - // A += carries from ADCXQ and 
ADOXQ - MOVQ $0, AX - ADCXQ AX, BP - ADOXQ AX, BP - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ R14, DX - - // clear the flags - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ q<>+0(SB), AX, R9 - ADCXQ R14, AX - MOVQ R9, R14 - - // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R15, R14 - MULXQ q<>+8(SB), AX, R15 - ADOXQ AX, R14 - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ CX, R15 - MULXQ q<>+16(SB), AX, CX - ADOXQ AX, R15 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ BX, CX - MULXQ q<>+24(SB), AX, BX - ADOXQ AX, CX - - // (C,t[3]) := t[4] + m*q[4] + C - ADCXQ SI, BX - MULXQ q<>+32(SB), AX, SI - ADOXQ AX, BX - - // (C,t[4]) := t[5] + m*q[5] + C - ADCXQ DI, SI - MULXQ q<>+40(SB), AX, DI - ADOXQ AX, SI - - // t[5] = C + A - MOVQ $0, AX - ADCXQ AX, DI - ADOXQ BP, DI - - // clear the flags - XORQ AX, AX - MOVQ 16(R13), DX - - // (A,t[0]) := t[0] + x[0]*y[2] + A - MULXQ R10, AX, BP - ADOXQ AX, R14 - - // (A,t[1]) := t[1] + x[1]*y[2] + A - ADCXQ BP, R15 - MULXQ R11, AX, BP - ADOXQ AX, R15 - - // (A,t[2]) := t[2] + x[2]*y[2] + A - ADCXQ BP, CX - MULXQ R12, AX, BP - ADOXQ AX, CX - - // (A,t[3]) := t[3] + x[3]*y[2] + A - ADCXQ BP, BX - MULXQ 24(R8), AX, BP - ADOXQ AX, BX - - // (A,t[4]) := t[4] + x[4]*y[2] + A - ADCXQ BP, SI - MULXQ 32(R8), AX, BP - ADOXQ AX, SI - - // (A,t[5]) := t[5] + x[5]*y[2] + A - ADCXQ BP, DI - MULXQ 40(R8), AX, BP - ADOXQ AX, DI - - // A += carries from ADCXQ and ADOXQ - MOVQ $0, AX - ADCXQ AX, BP - ADOXQ AX, BP - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ R14, DX - - // clear the flags - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ q<>+0(SB), AX, R9 - ADCXQ R14, AX - MOVQ R9, R14 - - // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R15, R14 - MULXQ q<>+8(SB), AX, R15 - ADOXQ AX, R14 - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ CX, R15 - MULXQ q<>+16(SB), AX, CX - ADOXQ AX, R15 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ BX, CX - MULXQ q<>+24(SB), AX, BX - ADOXQ AX, CX - - // (C,t[3]) := t[4] + m*q[4] + C - ADCXQ SI, BX - MULXQ q<>+32(SB), AX, SI - ADOXQ AX, BX - - // (C,t[4]) := t[5] + m*q[5] + C - ADCXQ DI, SI - MULXQ q<>+40(SB), AX, DI - ADOXQ AX, SI - - // t[5] = C + A - MOVQ $0, AX - ADCXQ AX, DI - ADOXQ BP, DI - - // clear the flags - XORQ AX, AX - MOVQ 24(R13), DX - - // (A,t[0]) := t[0] + x[0]*y[3] + A - MULXQ R10, AX, BP - ADOXQ AX, R14 - - // (A,t[1]) := t[1] + x[1]*y[3] + A - ADCXQ BP, R15 - MULXQ R11, AX, BP - ADOXQ AX, R15 - - // (A,t[2]) := t[2] + x[2]*y[3] + A - ADCXQ BP, CX - MULXQ R12, AX, BP - ADOXQ AX, CX - - // (A,t[3]) := t[3] + x[3]*y[3] + A - ADCXQ BP, BX - MULXQ 24(R8), AX, BP - ADOXQ AX, BX - - // (A,t[4]) := t[4] + x[4]*y[3] + A - ADCXQ BP, SI - MULXQ 32(R8), AX, BP - ADOXQ AX, SI - - // (A,t[5]) := t[5] + x[5]*y[3] + A - ADCXQ BP, DI - MULXQ 40(R8), AX, BP - ADOXQ AX, DI - - // A += carries from ADCXQ and ADOXQ - MOVQ $0, AX - ADCXQ AX, BP - ADOXQ AX, BP - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ R14, DX - - // clear the flags - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ q<>+0(SB), AX, R9 - ADCXQ R14, AX - MOVQ R9, R14 - - // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R15, R14 - MULXQ q<>+8(SB), AX, R15 - ADOXQ AX, R14 - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ CX, R15 - MULXQ q<>+16(SB), AX, CX - ADOXQ AX, R15 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ BX, CX - MULXQ q<>+24(SB), AX, BX - ADOXQ AX, CX - - // (C,t[3]) := t[4] + m*q[4] + C - ADCXQ SI, BX - MULXQ q<>+32(SB), AX, SI - ADOXQ AX, BX - - // (C,t[4]) := t[5] + m*q[5] + C - ADCXQ DI, SI - MULXQ q<>+40(SB), AX, DI - ADOXQ AX, SI - - // t[5] = C + A - 
MOVQ $0, AX - ADCXQ AX, DI - ADOXQ BP, DI - - // clear the flags - XORQ AX, AX - MOVQ 32(R13), DX - - // (A,t[0]) := t[0] + x[0]*y[4] + A - MULXQ R10, AX, BP - ADOXQ AX, R14 - - // (A,t[1]) := t[1] + x[1]*y[4] + A - ADCXQ BP, R15 - MULXQ R11, AX, BP - ADOXQ AX, R15 - - // (A,t[2]) := t[2] + x[2]*y[4] + A - ADCXQ BP, CX - MULXQ R12, AX, BP - ADOXQ AX, CX - - // (A,t[3]) := t[3] + x[3]*y[4] + A - ADCXQ BP, BX - MULXQ 24(R8), AX, BP - ADOXQ AX, BX - - // (A,t[4]) := t[4] + x[4]*y[4] + A - ADCXQ BP, SI - MULXQ 32(R8), AX, BP - ADOXQ AX, SI - - // (A,t[5]) := t[5] + x[5]*y[4] + A - ADCXQ BP, DI - MULXQ 40(R8), AX, BP - ADOXQ AX, DI - - // A += carries from ADCXQ and ADOXQ - MOVQ $0, AX - ADCXQ AX, BP - ADOXQ AX, BP - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ R14, DX - - // clear the flags - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ q<>+0(SB), AX, R9 - ADCXQ R14, AX - MOVQ R9, R14 - - // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R15, R14 - MULXQ q<>+8(SB), AX, R15 - ADOXQ AX, R14 - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ CX, R15 - MULXQ q<>+16(SB), AX, CX - ADOXQ AX, R15 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ BX, CX - MULXQ q<>+24(SB), AX, BX - ADOXQ AX, CX - - // (C,t[3]) := t[4] + m*q[4] + C - ADCXQ SI, BX - MULXQ q<>+32(SB), AX, SI - ADOXQ AX, BX - - // (C,t[4]) := t[5] + m*q[5] + C - ADCXQ DI, SI - MULXQ q<>+40(SB), AX, DI - ADOXQ AX, SI - - // t[5] = C + A - MOVQ $0, AX - ADCXQ AX, DI - ADOXQ BP, DI - - // clear the flags - XORQ AX, AX - MOVQ 40(R13), DX - - // (A,t[0]) := t[0] + x[0]*y[5] + A - MULXQ R10, AX, BP - ADOXQ AX, R14 - - // (A,t[1]) := t[1] + x[1]*y[5] + A - ADCXQ BP, R15 - MULXQ R11, AX, BP - ADOXQ AX, R15 - - // (A,t[2]) := t[2] + x[2]*y[5] + A - ADCXQ BP, CX - MULXQ R12, AX, BP - ADOXQ AX, CX - - // (A,t[3]) := t[3] + x[3]*y[5] + A - ADCXQ BP, BX - MULXQ 24(R8), AX, BP - ADOXQ AX, BX - - // (A,t[4]) := t[4] + x[4]*y[5] + A - ADCXQ BP, SI - MULXQ 32(R8), AX, BP - ADOXQ AX, SI - - // (A,t[5]) := t[5] + x[5]*y[5] + A - ADCXQ BP, DI - MULXQ 40(R8), AX, BP - ADOXQ AX, DI - - // A += carries from ADCXQ and ADOXQ - MOVQ $0, AX - ADCXQ AX, BP - ADOXQ AX, BP - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ R14, DX - - // clear the flags - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ q<>+0(SB), AX, R9 - ADCXQ R14, AX - MOVQ R9, R14 - - // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R15, R14 - MULXQ q<>+8(SB), AX, R15 - ADOXQ AX, R14 - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ CX, R15 - MULXQ q<>+16(SB), AX, CX - ADOXQ AX, R15 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ BX, CX - MULXQ q<>+24(SB), AX, BX - ADOXQ AX, CX - - // (C,t[3]) := t[4] + m*q[4] + C - ADCXQ SI, BX - MULXQ q<>+32(SB), AX, SI - ADOXQ AX, BX - - // (C,t[4]) := t[5] + m*q[5] + C - ADCXQ DI, SI - MULXQ q<>+40(SB), AX, DI - ADOXQ AX, SI - - // t[5] = C + A - MOVQ $0, AX - ADCXQ AX, DI - ADOXQ BP, DI - - // reduce element(R14,R15,CX,BX,SI,DI) using temp registers (R9,R8,R13,R10,R11,R12) - REDUCE(R14,R15,CX,BX,SI,DI,R9,R8,R13,R10,R11,R12) - - MOVQ res+0(FP), AX - MOVQ R14, 0(AX) - MOVQ R15, 8(AX) - MOVQ CX, 16(AX) - MOVQ BX, 24(AX) - MOVQ SI, 32(AX) - MOVQ DI, 40(AX) - RET - -noAdx_1: - MOVQ res+0(FP), AX - MOVQ AX, (SP) - MOVQ x+8(FP), AX - MOVQ AX, 8(SP) - MOVQ y+16(FP), AX - MOVQ AX, 16(SP) - CALL ·_mulGeneric(SB) - RET - -TEXT ·fromMont(SB), $8-8 - NO_LOCAL_POINTERS - - // the algorithm is described here - // https://hackmd.io/@gnark/modular_multiplication - // when y = 1 we have: - // for i=0 to N-1 - // t[i] = x[i] - // for i=0 to N-1 - // m := t[0]*q'[0] mod W - // C,_ := t[0] + 
m*q[0] - // for j=1 to N-1 - // (C,t[j-1]) := t[j] + m*q[j] + C - // t[N-1] = C - CMPB ·supportAdx(SB), $1 - JNE noAdx_2 - MOVQ res+0(FP), DX - MOVQ 0(DX), R14 - MOVQ 8(DX), R15 - MOVQ 16(DX), CX - MOVQ 24(DX), BX - MOVQ 32(DX), SI - MOVQ 40(DX), DI - XORQ DX, DX - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ R14, DX - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ q<>+0(SB), AX, BP - ADCXQ R14, AX - MOVQ BP, R14 - - // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R15, R14 - MULXQ q<>+8(SB), AX, R15 - ADOXQ AX, R14 - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ CX, R15 - MULXQ q<>+16(SB), AX, CX - ADOXQ AX, R15 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ BX, CX - MULXQ q<>+24(SB), AX, BX - ADOXQ AX, CX - - // (C,t[3]) := t[4] + m*q[4] + C - ADCXQ SI, BX - MULXQ q<>+32(SB), AX, SI - ADOXQ AX, BX - - // (C,t[4]) := t[5] + m*q[5] + C - ADCXQ DI, SI - MULXQ q<>+40(SB), AX, DI - ADOXQ AX, SI - MOVQ $0, AX - ADCXQ AX, DI - ADOXQ AX, DI - XORQ DX, DX - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ R14, DX - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ q<>+0(SB), AX, BP - ADCXQ R14, AX - MOVQ BP, R14 - - // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R15, R14 - MULXQ q<>+8(SB), AX, R15 - ADOXQ AX, R14 - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ CX, R15 - MULXQ q<>+16(SB), AX, CX - ADOXQ AX, R15 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ BX, CX - MULXQ q<>+24(SB), AX, BX - ADOXQ AX, CX - - // (C,t[3]) := t[4] + m*q[4] + C - ADCXQ SI, BX - MULXQ q<>+32(SB), AX, SI - ADOXQ AX, BX - - // (C,t[4]) := t[5] + m*q[5] + C - ADCXQ DI, SI - MULXQ q<>+40(SB), AX, DI - ADOXQ AX, SI - MOVQ $0, AX - ADCXQ AX, DI - ADOXQ AX, DI - XORQ DX, DX - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ R14, DX - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ q<>+0(SB), AX, BP - ADCXQ R14, AX - MOVQ BP, R14 - - // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R15, R14 - MULXQ q<>+8(SB), AX, R15 - ADOXQ AX, R14 - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ CX, R15 - MULXQ q<>+16(SB), AX, CX - ADOXQ AX, R15 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ BX, CX - MULXQ q<>+24(SB), AX, BX - ADOXQ AX, CX - - // (C,t[3]) := t[4] + m*q[4] + C - ADCXQ SI, BX - MULXQ q<>+32(SB), AX, SI - ADOXQ AX, BX - - // (C,t[4]) := t[5] + m*q[5] + C - ADCXQ DI, SI - MULXQ q<>+40(SB), AX, DI - ADOXQ AX, SI - MOVQ $0, AX - ADCXQ AX, DI - ADOXQ AX, DI - XORQ DX, DX - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ R14, DX - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ q<>+0(SB), AX, BP - ADCXQ R14, AX - MOVQ BP, R14 - - // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R15, R14 - MULXQ q<>+8(SB), AX, R15 - ADOXQ AX, R14 - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ CX, R15 - MULXQ q<>+16(SB), AX, CX - ADOXQ AX, R15 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ BX, CX - MULXQ q<>+24(SB), AX, BX - ADOXQ AX, CX - - // (C,t[3]) := t[4] + m*q[4] + C - ADCXQ SI, BX - MULXQ q<>+32(SB), AX, SI - ADOXQ AX, BX - - // (C,t[4]) := t[5] + m*q[5] + C - ADCXQ DI, SI - MULXQ q<>+40(SB), AX, DI - ADOXQ AX, SI - MOVQ $0, AX - ADCXQ AX, DI - ADOXQ AX, DI - XORQ DX, DX - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ R14, DX - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ q<>+0(SB), AX, BP - ADCXQ R14, AX - MOVQ BP, R14 - - // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R15, R14 - MULXQ q<>+8(SB), AX, R15 - ADOXQ AX, R14 - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ CX, R15 - MULXQ q<>+16(SB), AX, CX - ADOXQ AX, R15 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ BX, CX - MULXQ q<>+24(SB), AX, BX - ADOXQ AX, CX - - // (C,t[3]) := t[4] + m*q[4] + C 
- ADCXQ SI, BX - MULXQ q<>+32(SB), AX, SI - ADOXQ AX, BX - - // (C,t[4]) := t[5] + m*q[5] + C - ADCXQ DI, SI - MULXQ q<>+40(SB), AX, DI - ADOXQ AX, SI - MOVQ $0, AX - ADCXQ AX, DI - ADOXQ AX, DI - XORQ DX, DX - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ R14, DX - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ q<>+0(SB), AX, BP - ADCXQ R14, AX - MOVQ BP, R14 - - // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R15, R14 - MULXQ q<>+8(SB), AX, R15 - ADOXQ AX, R14 - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ CX, R15 - MULXQ q<>+16(SB), AX, CX - ADOXQ AX, R15 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ BX, CX - MULXQ q<>+24(SB), AX, BX - ADOXQ AX, CX - - // (C,t[3]) := t[4] + m*q[4] + C - ADCXQ SI, BX - MULXQ q<>+32(SB), AX, SI - ADOXQ AX, BX - - // (C,t[4]) := t[5] + m*q[5] + C - ADCXQ DI, SI - MULXQ q<>+40(SB), AX, DI - ADOXQ AX, SI - MOVQ $0, AX - ADCXQ AX, DI - ADOXQ AX, DI - - // reduce element(R14,R15,CX,BX,SI,DI) using temp registers (R8,R9,R10,R11,R12,R13) - REDUCE(R14,R15,CX,BX,SI,DI,R8,R9,R10,R11,R12,R13) - - MOVQ res+0(FP), AX - MOVQ R14, 0(AX) - MOVQ R15, 8(AX) - MOVQ CX, 16(AX) - MOVQ BX, 24(AX) - MOVQ SI, 32(AX) - MOVQ DI, 40(AX) - RET - -noAdx_2: - MOVQ res+0(FP), AX - MOVQ AX, (SP) - CALL ·_fromMontGeneric(SB) - RET diff --git a/ecc/bls12-381/fp/element_ops_amd64.go b/ecc/bls12-381/fp/element_ops_amd64.go index 83bba45ae..ed2803d71 100644 --- a/ecc/bls12-381/fp/element_ops_amd64.go +++ b/ecc/bls12-381/fp/element_ops_amd64.go @@ -50,48 +50,8 @@ func Butterfly(a, b *Element) // x and y must be less than q func (z *Element) Mul(x, y *Element) *Element { - // Implements CIOS multiplication -- section 2.3.2 of Tolga Acar's thesis - // https://www.microsoft.com/en-us/research/wp-content/uploads/1998/06/97Acar.pdf - // - // The algorithm: - // - // for i=0 to N-1 - // C := 0 - // for j=0 to N-1 - // (C,t[j]) := t[j] + x[j]*y[i] + C - // (t[N+1],t[N]) := t[N] + C - // - // C := 0 - // m := t[0]*q'[0] mod D - // (C,_) := t[0] + m*q[0] - // for j=1 to N-1 - // (C,t[j-1]) := t[j] + m*q[j] + C - // - // (C,t[N-1]) := t[N] + C - // t[N] := t[N+1] + C - // - // → N is the number of machine words needed to store the modulus q - // → D is the word size. For example, on a 64-bit architecture D is 2 64 - // → x[i], y[i], q[i] is the ith word of the numbers x,y,q - // → q'[0] is the lowest word of the number -q⁻¹ mod r. This quantity is pre-computed, as it does not depend on the inputs. - // → t is a temporary array of size N+2 - // → C, S are machine words. A pair (C,S) refers to (hi-bits, lo-bits) of a two-word number - // - // As described here https://hackmd.io/@gnark/modular_multiplication we can get rid of one carry chain and simplify: - // (also described in https://eprint.iacr.org/2022/1400.pdf annex) - // - // for i=0 to N-1 - // (A,t[0]) := t[0] + x[0]*y[i] - // m := t[0]*q'[0] mod W - // C,_ := t[0] + m*q[0] - // for j=1 to N-1 - // (A,t[j]) := t[j] + x[j]*y[i] + A - // (C,t[j-1]) := t[j] + m*q[j] + C - // - // t[N-1] = C + A - // - // This optimization saves 5N + 2 additions in the algorithm, and can be used whenever the highest bit - // of the modulus is zero (and not all of the remaining bits are set). + // Algorithm 2 of "Faster Montgomery Multiplication and Multi-Scalar-Multiplication for SNARKS" + // by Y. El Housni and G. 
Botrel https://doi.org/10.46586/tches.v2023.i3.504-521 mul(z, x, y) return z diff --git a/ecc/bls12-381/fp/element_ops_amd64.s b/ecc/bls12-381/fp/element_ops_amd64.s index 830b2dd63..f4a844f1d 100644 --- a/ecc/bls12-381/fp/element_ops_amd64.s +++ b/ecc/bls12-381/fp/element_ops_amd64.s @@ -1,21 +1,13 @@ // +build !purego -// Copyright 2020 ConsenSys Software Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. +#define q0 $0xb9feffffffffaaab +#define q1 $0x1eabfffeb153ffff +#define q2 $0x6730d2a0f6b0f624 +#define q3 $0x64774b84f38512bf +#define q4 $0x4b1ba7b6434bacd7 +#define q5 $0x1a0111ea397fe69a -#include "textflag.h" -#include "funcdata.h" +#include "../../../field/asm/element_6w_amd64.h" // modulus q DATA q<>+0(SB)/8, $0xb9feffffffffaaab @@ -30,277 +22,3 @@ GLOBL q<>(SB), (RODATA+NOPTR), $48 DATA qInv0<>(SB)/8, $0x89f3fffcfffcfffd GLOBL qInv0<>(SB), (RODATA+NOPTR), $8 -#define REDUCE(ra0, ra1, ra2, ra3, ra4, ra5, rb0, rb1, rb2, rb3, rb4, rb5) \ - MOVQ ra0, rb0; \ - SUBQ q<>(SB), ra0; \ - MOVQ ra1, rb1; \ - SBBQ q<>+8(SB), ra1; \ - MOVQ ra2, rb2; \ - SBBQ q<>+16(SB), ra2; \ - MOVQ ra3, rb3; \ - SBBQ q<>+24(SB), ra3; \ - MOVQ ra4, rb4; \ - SBBQ q<>+32(SB), ra4; \ - MOVQ ra5, rb5; \ - SBBQ q<>+40(SB), ra5; \ - CMOVQCS rb0, ra0; \ - CMOVQCS rb1, ra1; \ - CMOVQCS rb2, ra2; \ - CMOVQCS rb3, ra3; \ - CMOVQCS rb4, ra4; \ - CMOVQCS rb5, ra5; \ - -TEXT ·reduce(SB), NOSPLIT, $0-8 - MOVQ res+0(FP), AX - MOVQ 0(AX), DX - MOVQ 8(AX), CX - MOVQ 16(AX), BX - MOVQ 24(AX), SI - MOVQ 32(AX), DI - MOVQ 40(AX), R8 - - // reduce element(DX,CX,BX,SI,DI,R8) using temp registers (R9,R10,R11,R12,R13,R14) - REDUCE(DX,CX,BX,SI,DI,R8,R9,R10,R11,R12,R13,R14) - - MOVQ DX, 0(AX) - MOVQ CX, 8(AX) - MOVQ BX, 16(AX) - MOVQ SI, 24(AX) - MOVQ DI, 32(AX) - MOVQ R8, 40(AX) - RET - -// MulBy3(x *Element) -TEXT ·MulBy3(SB), NOSPLIT, $0-8 - MOVQ x+0(FP), AX - MOVQ 0(AX), DX - MOVQ 8(AX), CX - MOVQ 16(AX), BX - MOVQ 24(AX), SI - MOVQ 32(AX), DI - MOVQ 40(AX), R8 - ADDQ DX, DX - ADCQ CX, CX - ADCQ BX, BX - ADCQ SI, SI - ADCQ DI, DI - ADCQ R8, R8 - - // reduce element(DX,CX,BX,SI,DI,R8) using temp registers (R9,R10,R11,R12,R13,R14) - REDUCE(DX,CX,BX,SI,DI,R8,R9,R10,R11,R12,R13,R14) - - ADDQ 0(AX), DX - ADCQ 8(AX), CX - ADCQ 16(AX), BX - ADCQ 24(AX), SI - ADCQ 32(AX), DI - ADCQ 40(AX), R8 - - // reduce element(DX,CX,BX,SI,DI,R8) using temp registers (R15,R9,R10,R11,R12,R13) - REDUCE(DX,CX,BX,SI,DI,R8,R15,R9,R10,R11,R12,R13) - - MOVQ DX, 0(AX) - MOVQ CX, 8(AX) - MOVQ BX, 16(AX) - MOVQ SI, 24(AX) - MOVQ DI, 32(AX) - MOVQ R8, 40(AX) - RET - -// MulBy5(x *Element) -TEXT ·MulBy5(SB), NOSPLIT, $0-8 - MOVQ x+0(FP), AX - MOVQ 0(AX), DX - MOVQ 8(AX), CX - MOVQ 16(AX), BX - MOVQ 24(AX), SI - MOVQ 32(AX), DI - MOVQ 40(AX), R8 - ADDQ DX, DX - ADCQ CX, CX - ADCQ BX, BX - ADCQ SI, SI - ADCQ DI, DI - ADCQ R8, R8 - - // reduce element(DX,CX,BX,SI,DI,R8) using temp registers (R9,R10,R11,R12,R13,R14) - REDUCE(DX,CX,BX,SI,DI,R8,R9,R10,R11,R12,R13,R14) - - ADDQ DX, DX - ADCQ CX, CX - ADCQ BX, BX - ADCQ SI, SI - ADCQ DI, DI - ADCQ R8, R8 - - 
// reduce element(DX,CX,BX,SI,DI,R8) using temp registers (R15,R9,R10,R11,R12,R13) - REDUCE(DX,CX,BX,SI,DI,R8,R15,R9,R10,R11,R12,R13) - - ADDQ 0(AX), DX - ADCQ 8(AX), CX - ADCQ 16(AX), BX - ADCQ 24(AX), SI - ADCQ 32(AX), DI - ADCQ 40(AX), R8 - - // reduce element(DX,CX,BX,SI,DI,R8) using temp registers (R14,R15,R9,R10,R11,R12) - REDUCE(DX,CX,BX,SI,DI,R8,R14,R15,R9,R10,R11,R12) - - MOVQ DX, 0(AX) - MOVQ CX, 8(AX) - MOVQ BX, 16(AX) - MOVQ SI, 24(AX) - MOVQ DI, 32(AX) - MOVQ R8, 40(AX) - RET - -// MulBy13(x *Element) -TEXT ·MulBy13(SB), $40-8 - MOVQ x+0(FP), AX - MOVQ 0(AX), DX - MOVQ 8(AX), CX - MOVQ 16(AX), BX - MOVQ 24(AX), SI - MOVQ 32(AX), DI - MOVQ 40(AX), R8 - ADDQ DX, DX - ADCQ CX, CX - ADCQ BX, BX - ADCQ SI, SI - ADCQ DI, DI - ADCQ R8, R8 - - // reduce element(DX,CX,BX,SI,DI,R8) using temp registers (R9,R10,R11,R12,R13,R14) - REDUCE(DX,CX,BX,SI,DI,R8,R9,R10,R11,R12,R13,R14) - - ADDQ DX, DX - ADCQ CX, CX - ADCQ BX, BX - ADCQ SI, SI - ADCQ DI, DI - ADCQ R8, R8 - - // reduce element(DX,CX,BX,SI,DI,R8) using temp registers (R15,s0-8(SP),s1-16(SP),s2-24(SP),s3-32(SP),s4-40(SP)) - REDUCE(DX,CX,BX,SI,DI,R8,R15,s0-8(SP),s1-16(SP),s2-24(SP),s3-32(SP),s4-40(SP)) - - MOVQ DX, R15 - MOVQ CX, s0-8(SP) - MOVQ BX, s1-16(SP) - MOVQ SI, s2-24(SP) - MOVQ DI, s3-32(SP) - MOVQ R8, s4-40(SP) - ADDQ DX, DX - ADCQ CX, CX - ADCQ BX, BX - ADCQ SI, SI - ADCQ DI, DI - ADCQ R8, R8 - - // reduce element(DX,CX,BX,SI,DI,R8) using temp registers (R9,R10,R11,R12,R13,R14) - REDUCE(DX,CX,BX,SI,DI,R8,R9,R10,R11,R12,R13,R14) - - ADDQ R15, DX - ADCQ s0-8(SP), CX - ADCQ s1-16(SP), BX - ADCQ s2-24(SP), SI - ADCQ s3-32(SP), DI - ADCQ s4-40(SP), R8 - - // reduce element(DX,CX,BX,SI,DI,R8) using temp registers (R9,R10,R11,R12,R13,R14) - REDUCE(DX,CX,BX,SI,DI,R8,R9,R10,R11,R12,R13,R14) - - ADDQ 0(AX), DX - ADCQ 8(AX), CX - ADCQ 16(AX), BX - ADCQ 24(AX), SI - ADCQ 32(AX), DI - ADCQ 40(AX), R8 - - // reduce element(DX,CX,BX,SI,DI,R8) using temp registers (R9,R10,R11,R12,R13,R14) - REDUCE(DX,CX,BX,SI,DI,R8,R9,R10,R11,R12,R13,R14) - - MOVQ DX, 0(AX) - MOVQ CX, 8(AX) - MOVQ BX, 16(AX) - MOVQ SI, 24(AX) - MOVQ DI, 32(AX) - MOVQ R8, 40(AX) - RET - -// Butterfly(a, b *Element) sets a = a + b; b = a - b -TEXT ·Butterfly(SB), $48-16 - MOVQ a+0(FP), AX - MOVQ 0(AX), CX - MOVQ 8(AX), BX - MOVQ 16(AX), SI - MOVQ 24(AX), DI - MOVQ 32(AX), R8 - MOVQ 40(AX), R9 - MOVQ CX, R10 - MOVQ BX, R11 - MOVQ SI, R12 - MOVQ DI, R13 - MOVQ R8, R14 - MOVQ R9, R15 - XORQ AX, AX - MOVQ b+8(FP), DX - ADDQ 0(DX), CX - ADCQ 8(DX), BX - ADCQ 16(DX), SI - ADCQ 24(DX), DI - ADCQ 32(DX), R8 - ADCQ 40(DX), R9 - SUBQ 0(DX), R10 - SBBQ 8(DX), R11 - SBBQ 16(DX), R12 - SBBQ 24(DX), R13 - SBBQ 32(DX), R14 - SBBQ 40(DX), R15 - MOVQ CX, s0-8(SP) - MOVQ BX, s1-16(SP) - MOVQ SI, s2-24(SP) - MOVQ DI, s3-32(SP) - MOVQ R8, s4-40(SP) - MOVQ R9, s5-48(SP) - MOVQ $0xb9feffffffffaaab, CX - MOVQ $0x1eabfffeb153ffff, BX - MOVQ $0x6730d2a0f6b0f624, SI - MOVQ $0x64774b84f38512bf, DI - MOVQ $0x4b1ba7b6434bacd7, R8 - MOVQ $0x1a0111ea397fe69a, R9 - CMOVQCC AX, CX - CMOVQCC AX, BX - CMOVQCC AX, SI - CMOVQCC AX, DI - CMOVQCC AX, R8 - CMOVQCC AX, R9 - ADDQ CX, R10 - ADCQ BX, R11 - ADCQ SI, R12 - ADCQ DI, R13 - ADCQ R8, R14 - ADCQ R9, R15 - MOVQ s0-8(SP), CX - MOVQ s1-16(SP), BX - MOVQ s2-24(SP), SI - MOVQ s3-32(SP), DI - MOVQ s4-40(SP), R8 - MOVQ s5-48(SP), R9 - MOVQ R10, 0(DX) - MOVQ R11, 8(DX) - MOVQ R12, 16(DX) - MOVQ R13, 24(DX) - MOVQ R14, 32(DX) - MOVQ R15, 40(DX) - - // reduce element(CX,BX,SI,DI,R8,R9) using temp registers (R10,R11,R12,R13,R14,R15) - 
REDUCE(CX,BX,SI,DI,R8,R9,R10,R11,R12,R13,R14,R15) - - MOVQ a+0(FP), AX - MOVQ CX, 0(AX) - MOVQ BX, 8(AX) - MOVQ SI, 16(AX) - MOVQ DI, 24(AX) - MOVQ R8, 32(AX) - MOVQ R9, 40(AX) - RET diff --git a/ecc/bls12-381/fp/element_ops_purego.go b/ecc/bls12-381/fp/element_ops_purego.go index fc10b3df3..ee3f7e740 100644 --- a/ecc/bls12-381/fp/element_ops_purego.go +++ b/ecc/bls12-381/fp/element_ops_purego.go @@ -67,48 +67,8 @@ func reduce(z *Element) { // x and y must be less than q func (z *Element) Mul(x, y *Element) *Element { - // Implements CIOS multiplication -- section 2.3.2 of Tolga Acar's thesis - // https://www.microsoft.com/en-us/research/wp-content/uploads/1998/06/97Acar.pdf - // - // The algorithm: - // - // for i=0 to N-1 - // C := 0 - // for j=0 to N-1 - // (C,t[j]) := t[j] + x[j]*y[i] + C - // (t[N+1],t[N]) := t[N] + C - // - // C := 0 - // m := t[0]*q'[0] mod D - // (C,_) := t[0] + m*q[0] - // for j=1 to N-1 - // (C,t[j-1]) := t[j] + m*q[j] + C - // - // (C,t[N-1]) := t[N] + C - // t[N] := t[N+1] + C - // - // → N is the number of machine words needed to store the modulus q - // → D is the word size. For example, on a 64-bit architecture D is 2 64 - // → x[i], y[i], q[i] is the ith word of the numbers x,y,q - // → q'[0] is the lowest word of the number -q⁻¹ mod r. This quantity is pre-computed, as it does not depend on the inputs. - // → t is a temporary array of size N+2 - // → C, S are machine words. A pair (C,S) refers to (hi-bits, lo-bits) of a two-word number - // - // As described here https://hackmd.io/@gnark/modular_multiplication we can get rid of one carry chain and simplify: - // (also described in https://eprint.iacr.org/2022/1400.pdf annex) - // - // for i=0 to N-1 - // (A,t[0]) := t[0] + x[0]*y[i] - // m := t[0]*q'[0] mod W - // C,_ := t[0] + m*q[0] - // for j=1 to N-1 - // (A,t[j]) := t[j] + x[j]*y[i] + A - // (C,t[j-1]) := t[j] + m*q[j] + C - // - // t[N-1] = C + A - // - // This optimization saves 5N + 2 additions in the algorithm, and can be used whenever the highest bit - // of the modulus is zero (and not all of the remaining bits are set). + // Algorithm 2 of "Faster Montgomery Multiplication and Multi-Scalar-Multiplication for SNARKS" + // by Y. El Housni and G. 
Botrel https://doi.org/10.46586/tches.v2023.i3.504-521 var t0, t1, t2, t3, t4, t5 uint64 var u0, u1, u2, u3, u4, u5 uint64 diff --git a/ecc/bls12-381/fp/element_test.go b/ecc/bls12-381/fp/element_test.go index cb26b5425..b9ecd277e 100644 --- a/ecc/bls12-381/fp/element_test.go +++ b/ecc/bls12-381/fp/element_test.go @@ -641,7 +641,6 @@ func TestElementBitLen(t *testing.T) { )) properties.TestingRun(t, gopter.ConsoleReporter(false)) - } func TestElementButterflies(t *testing.T) { @@ -723,11 +722,17 @@ func TestElementVecOps(t *testing.T) { // set m to max values element // it's not really q-1 (since we have montgomery representation) // but it's the "largest" legal value - qMinus1 := new(big.Int).Sub(Modulus(), big.NewInt(1)) - - var eQMinus1 Element - for i, v := range qMinus1.Bits() { - eQMinus1[i] = uint64(v) + eQMinus1 := qElement + if eQMinus1[0] != 0 { + eQMinus1[0]-- + } else { + eQMinus1[0] = ^uint64(0) + for i := 1; i < len(eQMinus1); i++ { + if eQMinus1[i] != 0 { + eQMinus1[i]-- + break + } + } } for i := 0; i < N; i++ { diff --git a/ecc/bls12-381/fr/element.go b/ecc/bls12-381/fr/element.go index aa6c47cdd..2c9344acd 100644 --- a/ecc/bls12-381/fr/element.go +++ b/ecc/bls12-381/fr/element.go @@ -477,32 +477,8 @@ func (z *Element) Select(c int, x0 *Element, x1 *Element) *Element { // and is used for testing purposes. func _mulGeneric(z, x, y *Element) { - // Implements CIOS multiplication -- section 2.3.2 of Tolga Acar's thesis - // https://www.microsoft.com/en-us/research/wp-content/uploads/1998/06/97Acar.pdf - // - // The algorithm: - // - // for i=0 to N-1 - // C := 0 - // for j=0 to N-1 - // (C,t[j]) := t[j] + x[j]*y[i] + C - // (t[N+1],t[N]) := t[N] + C - // - // C := 0 - // m := t[0]*q'[0] mod D - // (C,_) := t[0] + m*q[0] - // for j=1 to N-1 - // (C,t[j-1]) := t[j] + m*q[j] + C - // - // (C,t[N-1]) := t[N] + C - // t[N] := t[N+1] + C - // - // → N is the number of machine words needed to store the modulus q - // → D is the word size. For example, on a 64-bit architecture D is 2 64 - // → x[i], y[i], q[i] is the ith word of the numbers x,y,q - // → q'[0] is the lowest word of the number -q⁻¹ mod r. This quantity is pre-computed, as it does not depend on the inputs. - // → t is a temporary array of size N+2 - // → C, S are machine words. A pair (C,S) refers to (hi-bits, lo-bits) of a two-word number + // Algorithm 2 of "Faster Montgomery Multiplication and Multi-Scalar-Multiplication for SNARKS" + // by Y. El Housni and G. Botrel https://doi.org/10.46586/tches.v2023.i3.504-521 var t [5]uint64 var D uint64 diff --git a/ecc/bls12-381/fr/element_mul_amd64.s b/ecc/bls12-381/fr/element_mul_amd64.s deleted file mode 100644 index 36064570f..000000000 --- a/ecc/bls12-381/fr/element_mul_amd64.s +++ /dev/null @@ -1,490 +0,0 @@ -// +build !purego - -// Copyright 2020 ConsenSys Software Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
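
The TestElementVecOps change earlier in this hunk region replaces the big.Int construction of the "largest legal value" with a limb-wise decrement of qElement. Since q is an odd prime its lowest limb is never zero, so in practice only the first branch of the patched code runs. A fully general in-place decrement with borrow propagation would look roughly like the sketch below (decrementSketch is illustrative only, not part of the patch):

package fp // sketch only

import "math/bits"

// decrementSketch computes e-1 in place, propagating the borrow across limbs.
func decrementSketch(e *Element) {
	var b uint64
	e[0], b = bits.Sub64(e[0], 1, 0)
	for i := 1; i < len(e) && b != 0; i++ {
		e[i], b = bits.Sub64(e[i], 0, b)
	}
}
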
- -#include "textflag.h" -#include "funcdata.h" - -// modulus q -DATA q<>+0(SB)/8, $0xffffffff00000001 -DATA q<>+8(SB)/8, $0x53bda402fffe5bfe -DATA q<>+16(SB)/8, $0x3339d80809a1d805 -DATA q<>+24(SB)/8, $0x73eda753299d7d48 -GLOBL q<>(SB), (RODATA+NOPTR), $32 - -// qInv0 q'[0] -DATA qInv0<>(SB)/8, $0xfffffffeffffffff -GLOBL qInv0<>(SB), (RODATA+NOPTR), $8 -// Mu -DATA mu<>(SB)/8, $0x00000002355094ed -GLOBL mu<>(SB), (RODATA+NOPTR), $8 - -#define REDUCE(ra0, ra1, ra2, ra3, rb0, rb1, rb2, rb3) \ - MOVQ ra0, rb0; \ - SUBQ q<>(SB), ra0; \ - MOVQ ra1, rb1; \ - SBBQ q<>+8(SB), ra1; \ - MOVQ ra2, rb2; \ - SBBQ q<>+16(SB), ra2; \ - MOVQ ra3, rb3; \ - SBBQ q<>+24(SB), ra3; \ - CMOVQCS rb0, ra0; \ - CMOVQCS rb1, ra1; \ - CMOVQCS rb2, ra2; \ - CMOVQCS rb3, ra3; \ - -// mul(res, x, y *Element) -TEXT ·mul(SB), $24-24 - - // the algorithm is described in the Element.Mul declaration (.go) - // however, to benefit from the ADCX and ADOX carry chains - // we split the inner loops in 2: - // for i=0 to N-1 - // for j=0 to N-1 - // (A,t[j]) := t[j] + x[j]*y[i] + A - // m := t[0]*q'[0] mod W - // C,_ := t[0] + m*q[0] - // for j=1 to N-1 - // (C,t[j-1]) := t[j] + m*q[j] + C - // t[N-1] = C + A - - NO_LOCAL_POINTERS - CMPB ·supportAdx(SB), $1 - JNE noAdx_1 - MOVQ x+8(FP), SI - - // x[0] -> DI - // x[1] -> R8 - // x[2] -> R9 - // x[3] -> R10 - MOVQ 0(SI), DI - MOVQ 8(SI), R8 - MOVQ 16(SI), R9 - MOVQ 24(SI), R10 - MOVQ y+16(FP), R11 - - // A -> BP - // t[0] -> R14 - // t[1] -> R13 - // t[2] -> CX - // t[3] -> BX - // clear the flags - XORQ AX, AX - MOVQ 0(R11), DX - - // (A,t[0]) := x[0]*y[0] + A - MULXQ DI, R14, R13 - - // (A,t[1]) := x[1]*y[0] + A - MULXQ R8, AX, CX - ADOXQ AX, R13 - - // (A,t[2]) := x[2]*y[0] + A - MULXQ R9, AX, BX - ADOXQ AX, CX - - // (A,t[3]) := x[3]*y[0] + A - MULXQ R10, AX, BP - ADOXQ AX, BX - - // A += carries from ADCXQ and ADOXQ - MOVQ $0, AX - ADOXQ AX, BP - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ R14, DX - - // clear the flags - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ q<>+0(SB), AX, R12 - ADCXQ R14, AX - MOVQ R12, R14 - - // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R13, R14 - MULXQ q<>+8(SB), AX, R13 - ADOXQ AX, R14 - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ CX, R13 - MULXQ q<>+16(SB), AX, CX - ADOXQ AX, R13 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ BX, CX - MULXQ q<>+24(SB), AX, BX - ADOXQ AX, CX - - // t[3] = C + A - MOVQ $0, AX - ADCXQ AX, BX - ADOXQ BP, BX - - // clear the flags - XORQ AX, AX - MOVQ 8(R11), DX - - // (A,t[0]) := t[0] + x[0]*y[1] + A - MULXQ DI, AX, BP - ADOXQ AX, R14 - - // (A,t[1]) := t[1] + x[1]*y[1] + A - ADCXQ BP, R13 - MULXQ R8, AX, BP - ADOXQ AX, R13 - - // (A,t[2]) := t[2] + x[2]*y[1] + A - ADCXQ BP, CX - MULXQ R9, AX, BP - ADOXQ AX, CX - - // (A,t[3]) := t[3] + x[3]*y[1] + A - ADCXQ BP, BX - MULXQ R10, AX, BP - ADOXQ AX, BX - - // A += carries from ADCXQ and ADOXQ - MOVQ $0, AX - ADCXQ AX, BP - ADOXQ AX, BP - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ R14, DX - - // clear the flags - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ q<>+0(SB), AX, R12 - ADCXQ R14, AX - MOVQ R12, R14 - - // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R13, R14 - MULXQ q<>+8(SB), AX, R13 - ADOXQ AX, R14 - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ CX, R13 - MULXQ q<>+16(SB), AX, CX - ADOXQ AX, R13 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ BX, CX - MULXQ q<>+24(SB), AX, BX - ADOXQ AX, CX - - // t[3] = C + A - MOVQ $0, AX - ADCXQ AX, BX - ADOXQ BP, BX - - // clear the flags - XORQ AX, AX - MOVQ 16(R11), DX - - // (A,t[0]) := t[0] + x[0]*y[2] 
+ A - MULXQ DI, AX, BP - ADOXQ AX, R14 - - // (A,t[1]) := t[1] + x[1]*y[2] + A - ADCXQ BP, R13 - MULXQ R8, AX, BP - ADOXQ AX, R13 - - // (A,t[2]) := t[2] + x[2]*y[2] + A - ADCXQ BP, CX - MULXQ R9, AX, BP - ADOXQ AX, CX - - // (A,t[3]) := t[3] + x[3]*y[2] + A - ADCXQ BP, BX - MULXQ R10, AX, BP - ADOXQ AX, BX - - // A += carries from ADCXQ and ADOXQ - MOVQ $0, AX - ADCXQ AX, BP - ADOXQ AX, BP - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ R14, DX - - // clear the flags - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ q<>+0(SB), AX, R12 - ADCXQ R14, AX - MOVQ R12, R14 - - // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R13, R14 - MULXQ q<>+8(SB), AX, R13 - ADOXQ AX, R14 - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ CX, R13 - MULXQ q<>+16(SB), AX, CX - ADOXQ AX, R13 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ BX, CX - MULXQ q<>+24(SB), AX, BX - ADOXQ AX, CX - - // t[3] = C + A - MOVQ $0, AX - ADCXQ AX, BX - ADOXQ BP, BX - - // clear the flags - XORQ AX, AX - MOVQ 24(R11), DX - - // (A,t[0]) := t[0] + x[0]*y[3] + A - MULXQ DI, AX, BP - ADOXQ AX, R14 - - // (A,t[1]) := t[1] + x[1]*y[3] + A - ADCXQ BP, R13 - MULXQ R8, AX, BP - ADOXQ AX, R13 - - // (A,t[2]) := t[2] + x[2]*y[3] + A - ADCXQ BP, CX - MULXQ R9, AX, BP - ADOXQ AX, CX - - // (A,t[3]) := t[3] + x[3]*y[3] + A - ADCXQ BP, BX - MULXQ R10, AX, BP - ADOXQ AX, BX - - // A += carries from ADCXQ and ADOXQ - MOVQ $0, AX - ADCXQ AX, BP - ADOXQ AX, BP - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ R14, DX - - // clear the flags - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ q<>+0(SB), AX, R12 - ADCXQ R14, AX - MOVQ R12, R14 - - // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R13, R14 - MULXQ q<>+8(SB), AX, R13 - ADOXQ AX, R14 - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ CX, R13 - MULXQ q<>+16(SB), AX, CX - ADOXQ AX, R13 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ BX, CX - MULXQ q<>+24(SB), AX, BX - ADOXQ AX, CX - - // t[3] = C + A - MOVQ $0, AX - ADCXQ AX, BX - ADOXQ BP, BX - - // reduce element(R14,R13,CX,BX) using temp registers (SI,R12,R11,DI) - REDUCE(R14,R13,CX,BX,SI,R12,R11,DI) - - MOVQ res+0(FP), AX - MOVQ R14, 0(AX) - MOVQ R13, 8(AX) - MOVQ CX, 16(AX) - MOVQ BX, 24(AX) - RET - -noAdx_1: - MOVQ res+0(FP), AX - MOVQ AX, (SP) - MOVQ x+8(FP), AX - MOVQ AX, 8(SP) - MOVQ y+16(FP), AX - MOVQ AX, 16(SP) - CALL ·_mulGeneric(SB) - RET - -TEXT ·fromMont(SB), $8-8 - NO_LOCAL_POINTERS - - // the algorithm is described here - // https://hackmd.io/@gnark/modular_multiplication - // when y = 1 we have: - // for i=0 to N-1 - // t[i] = x[i] - // for i=0 to N-1 - // m := t[0]*q'[0] mod W - // C,_ := t[0] + m*q[0] - // for j=1 to N-1 - // (C,t[j-1]) := t[j] + m*q[j] + C - // t[N-1] = C - CMPB ·supportAdx(SB), $1 - JNE noAdx_2 - MOVQ res+0(FP), DX - MOVQ 0(DX), R14 - MOVQ 8(DX), R13 - MOVQ 16(DX), CX - MOVQ 24(DX), BX - XORQ DX, DX - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ R14, DX - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ q<>+0(SB), AX, BP - ADCXQ R14, AX - MOVQ BP, R14 - - // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R13, R14 - MULXQ q<>+8(SB), AX, R13 - ADOXQ AX, R14 - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ CX, R13 - MULXQ q<>+16(SB), AX, CX - ADOXQ AX, R13 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ BX, CX - MULXQ q<>+24(SB), AX, BX - ADOXQ AX, CX - MOVQ $0, AX - ADCXQ AX, BX - ADOXQ AX, BX - XORQ DX, DX - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ R14, DX - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ q<>+0(SB), AX, BP - ADCXQ R14, AX - MOVQ BP, R14 - - // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ 
R13, R14 - MULXQ q<>+8(SB), AX, R13 - ADOXQ AX, R14 - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ CX, R13 - MULXQ q<>+16(SB), AX, CX - ADOXQ AX, R13 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ BX, CX - MULXQ q<>+24(SB), AX, BX - ADOXQ AX, CX - MOVQ $0, AX - ADCXQ AX, BX - ADOXQ AX, BX - XORQ DX, DX - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ R14, DX - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ q<>+0(SB), AX, BP - ADCXQ R14, AX - MOVQ BP, R14 - - // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R13, R14 - MULXQ q<>+8(SB), AX, R13 - ADOXQ AX, R14 - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ CX, R13 - MULXQ q<>+16(SB), AX, CX - ADOXQ AX, R13 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ BX, CX - MULXQ q<>+24(SB), AX, BX - ADOXQ AX, CX - MOVQ $0, AX - ADCXQ AX, BX - ADOXQ AX, BX - XORQ DX, DX - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ R14, DX - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ q<>+0(SB), AX, BP - ADCXQ R14, AX - MOVQ BP, R14 - - // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R13, R14 - MULXQ q<>+8(SB), AX, R13 - ADOXQ AX, R14 - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ CX, R13 - MULXQ q<>+16(SB), AX, CX - ADOXQ AX, R13 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ BX, CX - MULXQ q<>+24(SB), AX, BX - ADOXQ AX, CX - MOVQ $0, AX - ADCXQ AX, BX - ADOXQ AX, BX - - // reduce element(R14,R13,CX,BX) using temp registers (SI,DI,R8,R9) - REDUCE(R14,R13,CX,BX,SI,DI,R8,R9) - - MOVQ res+0(FP), AX - MOVQ R14, 0(AX) - MOVQ R13, 8(AX) - MOVQ CX, 16(AX) - MOVQ BX, 24(AX) - RET - -noAdx_2: - MOVQ res+0(FP), AX - MOVQ AX, (SP) - CALL ·_fromMontGeneric(SB) - RET diff --git a/ecc/bls12-381/fr/element_ops_amd64.go b/ecc/bls12-381/fr/element_ops_amd64.go index 49da23450..d21dabdaa 100644 --- a/ecc/bls12-381/fr/element_ops_amd64.go +++ b/ecc/bls12-381/fr/element_ops_amd64.go @@ -103,48 +103,8 @@ func sumVec(res *Element, a *Element, n uint64) // x and y must be less than q func (z *Element) Mul(x, y *Element) *Element { - // Implements CIOS multiplication -- section 2.3.2 of Tolga Acar's thesis - // https://www.microsoft.com/en-us/research/wp-content/uploads/1998/06/97Acar.pdf - // - // The algorithm: - // - // for i=0 to N-1 - // C := 0 - // for j=0 to N-1 - // (C,t[j]) := t[j] + x[j]*y[i] + C - // (t[N+1],t[N]) := t[N] + C - // - // C := 0 - // m := t[0]*q'[0] mod D - // (C,_) := t[0] + m*q[0] - // for j=1 to N-1 - // (C,t[j-1]) := t[j] + m*q[j] + C - // - // (C,t[N-1]) := t[N] + C - // t[N] := t[N+1] + C - // - // → N is the number of machine words needed to store the modulus q - // → D is the word size. For example, on a 64-bit architecture D is 2 64 - // → x[i], y[i], q[i] is the ith word of the numbers x,y,q - // → q'[0] is the lowest word of the number -q⁻¹ mod r. This quantity is pre-computed, as it does not depend on the inputs. - // → t is a temporary array of size N+2 - // → C, S are machine words. A pair (C,S) refers to (hi-bits, lo-bits) of a two-word number - // - // As described here https://hackmd.io/@gnark/modular_multiplication we can get rid of one carry chain and simplify: - // (also described in https://eprint.iacr.org/2022/1400.pdf annex) - // - // for i=0 to N-1 - // (A,t[0]) := t[0] + x[0]*y[i] - // m := t[0]*q'[0] mod W - // C,_ := t[0] + m*q[0] - // for j=1 to N-1 - // (A,t[j]) := t[j] + x[j]*y[i] + A - // (C,t[j-1]) := t[j] + m*q[j] + C - // - // t[N-1] = C + A - // - // This optimization saves 5N + 2 additions in the algorithm, and can be used whenever the highest bit - // of the modulus is zero (and not all of the remaining bits are set). 
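
Both the long CIOS explanation deleted just above and the shorter Algorithm 2 citation that replaces it describe the same contract: a Montgomery multiplication returning x*y*R^-1 mod q, with R = 2^(64N) for an N-word modulus. A big.Int reference of that contract can be handy for spot-checking any of the word-level implementations; montMulRef below is a hypothetical helper written for illustration, not part of gnark-crypto:

package main

import (
	"fmt"
	"math/big"
)

// montMulRef returns x*y*R^-1 mod q with R = 2^(64*nWords): the value the
// CIOS / Algorithm 2 word-level routines compute on Montgomery-form operands.
func montMulRef(x, y, q *big.Int, nWords uint) *big.Int {
	r := new(big.Int).Lsh(big.NewInt(1), 64*nWords)
	rInv := new(big.Int).ModInverse(r, q)
	z := new(big.Int).Mul(x, y)
	z.Mul(z, rInv)
	return z.Mod(z, q)
}

func main() {
	q := new(big.Int).SetUint64(0xffffffff00000001) // a single-word prime, 2^64 - 2^32 + 1
	fmt.Println(montMulRef(big.NewInt(3), big.NewInt(5), q, 1))
}

fromMont is the special case y = 1 of the same contract, i.e. z*R^-1 mod q, which is why its unrolled loop drops all the x[j]*y[i] products.
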
+ // Algorithm 2 of "Faster Montgomery Multiplication and Multi-Scalar-Multiplication for SNARKS" + // by Y. El Housni and G. Botrel https://doi.org/10.46586/tches.v2023.i3.504-521 mul(z, x, y) return z diff --git a/ecc/bls12-381/fr/element_ops_amd64.s b/ecc/bls12-381/fr/element_ops_amd64.s index 16bc87fbf..1be64ba50 100644 --- a/ecc/bls12-381/fr/element_ops_amd64.s +++ b/ecc/bls12-381/fr/element_ops_amd64.s @@ -1,21 +1,11 @@ // +build !purego -// Copyright 2020 ConsenSys Software Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. +#define q0 $0xffffffff00000001 +#define q1 $0x53bda402fffe5bfe +#define q2 $0x3339d80809a1d805 +#define q3 $0x73eda753299d7d48 -#include "textflag.h" -#include "funcdata.h" +#include "../../../field/asm/element_4w_amd64.h" // modulus q DATA q<>+0(SB)/8, $0xffffffff00000001 @@ -31,809 +21,3 @@ GLOBL qInv0<>(SB), (RODATA+NOPTR), $8 DATA mu<>(SB)/8, $0x00000002355094ed GLOBL mu<>(SB), (RODATA+NOPTR), $8 -#define REDUCE(ra0, ra1, ra2, ra3, rb0, rb1, rb2, rb3) \ - MOVQ ra0, rb0; \ - SUBQ q<>(SB), ra0; \ - MOVQ ra1, rb1; \ - SBBQ q<>+8(SB), ra1; \ - MOVQ ra2, rb2; \ - SBBQ q<>+16(SB), ra2; \ - MOVQ ra3, rb3; \ - SBBQ q<>+24(SB), ra3; \ - CMOVQCS rb0, ra0; \ - CMOVQCS rb1, ra1; \ - CMOVQCS rb2, ra2; \ - CMOVQCS rb3, ra3; \ - -TEXT ·reduce(SB), NOSPLIT, $0-8 - MOVQ res+0(FP), AX - MOVQ 0(AX), DX - MOVQ 8(AX), CX - MOVQ 16(AX), BX - MOVQ 24(AX), SI - - // reduce element(DX,CX,BX,SI) using temp registers (DI,R8,R9,R10) - REDUCE(DX,CX,BX,SI,DI,R8,R9,R10) - - MOVQ DX, 0(AX) - MOVQ CX, 8(AX) - MOVQ BX, 16(AX) - MOVQ SI, 24(AX) - RET - -// MulBy3(x *Element) -TEXT ·MulBy3(SB), NOSPLIT, $0-8 - MOVQ x+0(FP), AX - MOVQ 0(AX), DX - MOVQ 8(AX), CX - MOVQ 16(AX), BX - MOVQ 24(AX), SI - ADDQ DX, DX - ADCQ CX, CX - ADCQ BX, BX - ADCQ SI, SI - - // reduce element(DX,CX,BX,SI) using temp registers (DI,R8,R9,R10) - REDUCE(DX,CX,BX,SI,DI,R8,R9,R10) - - ADDQ 0(AX), DX - ADCQ 8(AX), CX - ADCQ 16(AX), BX - ADCQ 24(AX), SI - - // reduce element(DX,CX,BX,SI) using temp registers (R11,R12,R13,R14) - REDUCE(DX,CX,BX,SI,R11,R12,R13,R14) - - MOVQ DX, 0(AX) - MOVQ CX, 8(AX) - MOVQ BX, 16(AX) - MOVQ SI, 24(AX) - RET - -// MulBy5(x *Element) -TEXT ·MulBy5(SB), NOSPLIT, $0-8 - MOVQ x+0(FP), AX - MOVQ 0(AX), DX - MOVQ 8(AX), CX - MOVQ 16(AX), BX - MOVQ 24(AX), SI - ADDQ DX, DX - ADCQ CX, CX - ADCQ BX, BX - ADCQ SI, SI - - // reduce element(DX,CX,BX,SI) using temp registers (DI,R8,R9,R10) - REDUCE(DX,CX,BX,SI,DI,R8,R9,R10) - - ADDQ DX, DX - ADCQ CX, CX - ADCQ BX, BX - ADCQ SI, SI - - // reduce element(DX,CX,BX,SI) using temp registers (R11,R12,R13,R14) - REDUCE(DX,CX,BX,SI,R11,R12,R13,R14) - - ADDQ 0(AX), DX - ADCQ 8(AX), CX - ADCQ 16(AX), BX - ADCQ 24(AX), SI - - // reduce element(DX,CX,BX,SI) using temp registers (R15,DI,R8,R9) - REDUCE(DX,CX,BX,SI,R15,DI,R8,R9) - - MOVQ DX, 0(AX) - MOVQ CX, 8(AX) - MOVQ BX, 16(AX) - MOVQ SI, 24(AX) - RET - -// MulBy13(x *Element) -TEXT ·MulBy13(SB), NOSPLIT, $0-8 - MOVQ x+0(FP), AX - MOVQ 0(AX), DX - MOVQ 8(AX), CX - MOVQ 16(AX), BX - MOVQ 24(AX), SI - 
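
The deleted MulBy13, whose body continues below, follows the same doubling/addition chain as the other MulByN routines removed by this patch: 13x = 8x + 4x + x, with a reduction after every step. A plain-Go rendering of that chain, assuming this package's Element.Double and Element.Add methods; mulBy13Sketch is an illustrative name, not part of the patch:

package fr // sketch; uses this package's Element.Double and Element.Add

// mulBy13Sketch follows the assembly's chain, with a modular reduction folded
// into each step: 2x, 4x, 8x, 12x = 8x + 4x, and finally 13x = 12x + x.
func mulBy13Sketch(z *Element) {
	x := *z // keep the original operand
	var t Element
	t.Double(z)  // t = 2x
	t.Double(&t) // t = 4x
	z.Double(&t) // z = 8x
	z.Add(z, &t) // z = 12x
	z.Add(z, &x) // z = 13x
}
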
ADDQ DX, DX - ADCQ CX, CX - ADCQ BX, BX - ADCQ SI, SI - - // reduce element(DX,CX,BX,SI) using temp registers (DI,R8,R9,R10) - REDUCE(DX,CX,BX,SI,DI,R8,R9,R10) - - ADDQ DX, DX - ADCQ CX, CX - ADCQ BX, BX - ADCQ SI, SI - - // reduce element(DX,CX,BX,SI) using temp registers (R11,R12,R13,R14) - REDUCE(DX,CX,BX,SI,R11,R12,R13,R14) - - MOVQ DX, R11 - MOVQ CX, R12 - MOVQ BX, R13 - MOVQ SI, R14 - ADDQ DX, DX - ADCQ CX, CX - ADCQ BX, BX - ADCQ SI, SI - - // reduce element(DX,CX,BX,SI) using temp registers (DI,R8,R9,R10) - REDUCE(DX,CX,BX,SI,DI,R8,R9,R10) - - ADDQ R11, DX - ADCQ R12, CX - ADCQ R13, BX - ADCQ R14, SI - - // reduce element(DX,CX,BX,SI) using temp registers (DI,R8,R9,R10) - REDUCE(DX,CX,BX,SI,DI,R8,R9,R10) - - ADDQ 0(AX), DX - ADCQ 8(AX), CX - ADCQ 16(AX), BX - ADCQ 24(AX), SI - - // reduce element(DX,CX,BX,SI) using temp registers (DI,R8,R9,R10) - REDUCE(DX,CX,BX,SI,DI,R8,R9,R10) - - MOVQ DX, 0(AX) - MOVQ CX, 8(AX) - MOVQ BX, 16(AX) - MOVQ SI, 24(AX) - RET - -// Butterfly(a, b *Element) sets a = a + b; b = a - b -TEXT ·Butterfly(SB), NOSPLIT, $0-16 - MOVQ a+0(FP), AX - MOVQ 0(AX), CX - MOVQ 8(AX), BX - MOVQ 16(AX), SI - MOVQ 24(AX), DI - MOVQ CX, R8 - MOVQ BX, R9 - MOVQ SI, R10 - MOVQ DI, R11 - XORQ AX, AX - MOVQ b+8(FP), DX - ADDQ 0(DX), CX - ADCQ 8(DX), BX - ADCQ 16(DX), SI - ADCQ 24(DX), DI - SUBQ 0(DX), R8 - SBBQ 8(DX), R9 - SBBQ 16(DX), R10 - SBBQ 24(DX), R11 - MOVQ $0xffffffff00000001, R12 - MOVQ $0x53bda402fffe5bfe, R13 - MOVQ $0x3339d80809a1d805, R14 - MOVQ $0x73eda753299d7d48, R15 - CMOVQCC AX, R12 - CMOVQCC AX, R13 - CMOVQCC AX, R14 - CMOVQCC AX, R15 - ADDQ R12, R8 - ADCQ R13, R9 - ADCQ R14, R10 - ADCQ R15, R11 - MOVQ R8, 0(DX) - MOVQ R9, 8(DX) - MOVQ R10, 16(DX) - MOVQ R11, 24(DX) - - // reduce element(CX,BX,SI,DI) using temp registers (R8,R9,R10,R11) - REDUCE(CX,BX,SI,DI,R8,R9,R10,R11) - - MOVQ a+0(FP), AX - MOVQ CX, 0(AX) - MOVQ BX, 8(AX) - MOVQ SI, 16(AX) - MOVQ DI, 24(AX) - RET - -// addVec(res, a, b *Element, n uint64) res[0...n] = a[0...n] + b[0...n] -TEXT ·addVec(SB), NOSPLIT, $0-32 - MOVQ res+0(FP), CX - MOVQ a+8(FP), AX - MOVQ b+16(FP), DX - MOVQ n+24(FP), BX - -loop_1: - TESTQ BX, BX - JEQ done_2 // n == 0, we are done - - // a[0] -> SI - // a[1] -> DI - // a[2] -> R8 - // a[3] -> R9 - MOVQ 0(AX), SI - MOVQ 8(AX), DI - MOVQ 16(AX), R8 - MOVQ 24(AX), R9 - ADDQ 0(DX), SI - ADCQ 8(DX), DI - ADCQ 16(DX), R8 - ADCQ 24(DX), R9 - - // reduce element(SI,DI,R8,R9) using temp registers (R10,R11,R12,R13) - REDUCE(SI,DI,R8,R9,R10,R11,R12,R13) - - MOVQ SI, 0(CX) - MOVQ DI, 8(CX) - MOVQ R8, 16(CX) - MOVQ R9, 24(CX) - - // increment pointers to visit next element - ADDQ $32, AX - ADDQ $32, DX - ADDQ $32, CX - DECQ BX // decrement n - JMP loop_1 - -done_2: - RET - -// subVec(res, a, b *Element, n uint64) res[0...n] = a[0...n] - b[0...n] -TEXT ·subVec(SB), NOSPLIT, $0-32 - MOVQ res+0(FP), CX - MOVQ a+8(FP), AX - MOVQ b+16(FP), DX - MOVQ n+24(FP), BX - XORQ SI, SI - -loop_3: - TESTQ BX, BX - JEQ done_4 // n == 0, we are done - - // a[0] -> DI - // a[1] -> R8 - // a[2] -> R9 - // a[3] -> R10 - MOVQ 0(AX), DI - MOVQ 8(AX), R8 - MOVQ 16(AX), R9 - MOVQ 24(AX), R10 - SUBQ 0(DX), DI - SBBQ 8(DX), R8 - SBBQ 16(DX), R9 - SBBQ 24(DX), R10 - - // reduce (a-b) mod q - // q[0] -> R11 - // q[1] -> R12 - // q[2] -> R13 - // q[3] -> R14 - MOVQ $0xffffffff00000001, R11 - MOVQ $0x53bda402fffe5bfe, R12 - MOVQ $0x3339d80809a1d805, R13 - MOVQ $0x73eda753299d7d48, R14 - CMOVQCC SI, R11 - CMOVQCC SI, R12 - CMOVQCC SI, R13 - CMOVQCC SI, R14 - - // add registers (q or 0) to a, and set to result - ADDQ 
R11, DI - ADCQ R12, R8 - ADCQ R13, R9 - ADCQ R14, R10 - MOVQ DI, 0(CX) - MOVQ R8, 8(CX) - MOVQ R9, 16(CX) - MOVQ R10, 24(CX) - - // increment pointers to visit next element - ADDQ $32, AX - ADDQ $32, DX - ADDQ $32, CX - DECQ BX // decrement n - JMP loop_3 - -done_4: - RET - -// scalarMulVec(res, a, b *Element, n uint64) res[0...n] = a[0...n] * b -TEXT ·scalarMulVec(SB), $56-32 - CMPB ·supportAdx(SB), $1 - JNE noAdx_5 - MOVQ a+8(FP), R11 - MOVQ b+16(FP), R10 - MOVQ n+24(FP), R12 - - // scalar[0] -> SI - // scalar[1] -> DI - // scalar[2] -> R8 - // scalar[3] -> R9 - MOVQ 0(R10), SI - MOVQ 8(R10), DI - MOVQ 16(R10), R8 - MOVQ 24(R10), R9 - MOVQ res+0(FP), R10 - -loop_6: - TESTQ R12, R12 - JEQ done_7 // n == 0, we are done - - // TODO @gbotrel this is generated from the same macro as the unit mul, we should refactor this in a single asm function - // A -> BP - // t[0] -> R14 - // t[1] -> R15 - // t[2] -> CX - // t[3] -> BX - // clear the flags - XORQ AX, AX - MOVQ 0(R11), DX - - // (A,t[0]) := x[0]*y[0] + A - MULXQ SI, R14, R15 - - // (A,t[1]) := x[1]*y[0] + A - MULXQ DI, AX, CX - ADOXQ AX, R15 - - // (A,t[2]) := x[2]*y[0] + A - MULXQ R8, AX, BX - ADOXQ AX, CX - - // (A,t[3]) := x[3]*y[0] + A - MULXQ R9, AX, BP - ADOXQ AX, BX - - // A += carries from ADCXQ and ADOXQ - MOVQ $0, AX - ADOXQ AX, BP - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ R14, DX - - // clear the flags - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ q<>+0(SB), AX, R13 - ADCXQ R14, AX - MOVQ R13, R14 - - // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R15, R14 - MULXQ q<>+8(SB), AX, R15 - ADOXQ AX, R14 - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ CX, R15 - MULXQ q<>+16(SB), AX, CX - ADOXQ AX, R15 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ BX, CX - MULXQ q<>+24(SB), AX, BX - ADOXQ AX, CX - - // t[3] = C + A - MOVQ $0, AX - ADCXQ AX, BX - ADOXQ BP, BX - - // clear the flags - XORQ AX, AX - MOVQ 8(R11), DX - - // (A,t[0]) := t[0] + x[0]*y[1] + A - MULXQ SI, AX, BP - ADOXQ AX, R14 - - // (A,t[1]) := t[1] + x[1]*y[1] + A - ADCXQ BP, R15 - MULXQ DI, AX, BP - ADOXQ AX, R15 - - // (A,t[2]) := t[2] + x[2]*y[1] + A - ADCXQ BP, CX - MULXQ R8, AX, BP - ADOXQ AX, CX - - // (A,t[3]) := t[3] + x[3]*y[1] + A - ADCXQ BP, BX - MULXQ R9, AX, BP - ADOXQ AX, BX - - // A += carries from ADCXQ and ADOXQ - MOVQ $0, AX - ADCXQ AX, BP - ADOXQ AX, BP - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ R14, DX - - // clear the flags - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ q<>+0(SB), AX, R13 - ADCXQ R14, AX - MOVQ R13, R14 - - // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R15, R14 - MULXQ q<>+8(SB), AX, R15 - ADOXQ AX, R14 - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ CX, R15 - MULXQ q<>+16(SB), AX, CX - ADOXQ AX, R15 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ BX, CX - MULXQ q<>+24(SB), AX, BX - ADOXQ AX, CX - - // t[3] = C + A - MOVQ $0, AX - ADCXQ AX, BX - ADOXQ BP, BX - - // clear the flags - XORQ AX, AX - MOVQ 16(R11), DX - - // (A,t[0]) := t[0] + x[0]*y[2] + A - MULXQ SI, AX, BP - ADOXQ AX, R14 - - // (A,t[1]) := t[1] + x[1]*y[2] + A - ADCXQ BP, R15 - MULXQ DI, AX, BP - ADOXQ AX, R15 - - // (A,t[2]) := t[2] + x[2]*y[2] + A - ADCXQ BP, CX - MULXQ R8, AX, BP - ADOXQ AX, CX - - // (A,t[3]) := t[3] + x[3]*y[2] + A - ADCXQ BP, BX - MULXQ R9, AX, BP - ADOXQ AX, BX - - // A += carries from ADCXQ and ADOXQ - MOVQ $0, AX - ADCXQ AX, BP - ADOXQ AX, BP - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ R14, DX - - // clear the flags - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ q<>+0(SB), AX, R13 - ADCXQ R14, AX - MOVQ R13, 
R14 - - // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R15, R14 - MULXQ q<>+8(SB), AX, R15 - ADOXQ AX, R14 - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ CX, R15 - MULXQ q<>+16(SB), AX, CX - ADOXQ AX, R15 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ BX, CX - MULXQ q<>+24(SB), AX, BX - ADOXQ AX, CX - - // t[3] = C + A - MOVQ $0, AX - ADCXQ AX, BX - ADOXQ BP, BX - - // clear the flags - XORQ AX, AX - MOVQ 24(R11), DX - - // (A,t[0]) := t[0] + x[0]*y[3] + A - MULXQ SI, AX, BP - ADOXQ AX, R14 - - // (A,t[1]) := t[1] + x[1]*y[3] + A - ADCXQ BP, R15 - MULXQ DI, AX, BP - ADOXQ AX, R15 - - // (A,t[2]) := t[2] + x[2]*y[3] + A - ADCXQ BP, CX - MULXQ R8, AX, BP - ADOXQ AX, CX - - // (A,t[3]) := t[3] + x[3]*y[3] + A - ADCXQ BP, BX - MULXQ R9, AX, BP - ADOXQ AX, BX - - // A += carries from ADCXQ and ADOXQ - MOVQ $0, AX - ADCXQ AX, BP - ADOXQ AX, BP - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ R14, DX - - // clear the flags - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ q<>+0(SB), AX, R13 - ADCXQ R14, AX - MOVQ R13, R14 - - // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R15, R14 - MULXQ q<>+8(SB), AX, R15 - ADOXQ AX, R14 - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ CX, R15 - MULXQ q<>+16(SB), AX, CX - ADOXQ AX, R15 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ BX, CX - MULXQ q<>+24(SB), AX, BX - ADOXQ AX, CX - - // t[3] = C + A - MOVQ $0, AX - ADCXQ AX, BX - ADOXQ BP, BX - - // reduce t mod q - // reduce element(R14,R15,CX,BX) using temp registers (R13,AX,DX,s0-8(SP)) - REDUCE(R14,R15,CX,BX,R13,AX,DX,s0-8(SP)) - - MOVQ R14, 0(R10) - MOVQ R15, 8(R10) - MOVQ CX, 16(R10) - MOVQ BX, 24(R10) - - // increment pointers to visit next element - ADDQ $32, R11 - ADDQ $32, R10 - DECQ R12 // decrement n - JMP loop_6 - -done_7: - RET - -noAdx_5: - MOVQ n+24(FP), DX - MOVQ res+0(FP), AX - MOVQ AX, (SP) - MOVQ DX, 8(SP) - MOVQ DX, 16(SP) - MOVQ a+8(FP), AX - MOVQ AX, 24(SP) - MOVQ DX, 32(SP) - MOVQ DX, 40(SP) - MOVQ b+16(FP), AX - MOVQ AX, 48(SP) - CALL ·scalarMulVecGeneric(SB) - RET - -// sumVec(res, a *Element, n uint64) res = sum(a[0...n]) -TEXT ·sumVec(SB), NOSPLIT, $0-24 - - // Derived from https://github.com/a16z/vectorized-fields - // The idea is to use Z registers to accumulate the sum of elements, 8 by 8 - // first, we handle the case where n % 8 != 0 - // then, we loop over the elements 8 by 8 and accumulate the sum in the Z registers - // finally, we reduce the sum and store it in res - // - // when we move an element of a into a Z register, we use VPMOVZXDQ - // let's note w0...w3 the 4 64bits words of ai: w0 = ai[0], w1 = ai[1], w2 = ai[2], w3 = ai[3] - // VPMOVZXDQ(ai, Z0) will result in - // Z0= [hi(w3), lo(w3), hi(w2), lo(w2), hi(w1), lo(w1), hi(w0), lo(w0)] - // with hi(wi) the high 32 bits of wi and lo(wi) the low 32 bits of wi - // we can safely add 2^32+1 times Z registers constructed this way without overflow - // since each of this lo/hi bits are moved into a "64bits" slot - // N = 2^64-1 / 2^32-1 = 2^32+1 - // - // we then propagate the carry using ADOXQ and ADCXQ - // r0 = w0l + lo(woh) - // r1 = carry + hi(woh) + w1l + lo(w1h) - // r2 = carry + hi(w1h) + w2l + lo(w2h) - // r3 = carry + hi(w2h) + w3l + lo(w3h) - // r4 = carry + hi(w3h) - // we then reduce the sum using a single-word Barrett reduction - // we pick mu = 2^288 / q; which correspond to 4.5 words max. - // meaning we must guarantee that r4 fits in 32bits. 
- // To do so, we reduce N to 2^32-1 (since r4 receives 2 carries max) - - MOVQ a+8(FP), R14 - MOVQ n+16(FP), R15 - - // initialize accumulators Z0, Z1, Z2, Z3, Z4, Z5, Z6, Z7 - VXORPS Z0, Z0, Z0 - VMOVDQA64 Z0, Z1 - VMOVDQA64 Z0, Z2 - VMOVDQA64 Z0, Z3 - VMOVDQA64 Z0, Z4 - VMOVDQA64 Z0, Z5 - VMOVDQA64 Z0, Z6 - VMOVDQA64 Z0, Z7 - - // n % 8 -> CX - // n / 8 -> R15 - MOVQ R15, CX - ANDQ $7, CX - SHRQ $3, R15 - -loop_single_10: - TESTQ CX, CX - JEQ loop8by8_8 // n % 8 == 0, we are going to loop over 8 by 8 - VPMOVZXDQ 0(R14), Z8 - VPADDQ Z8, Z0, Z0 - ADDQ $32, R14 - DECQ CX // decrement nMod8 - JMP loop_single_10 - -loop8by8_8: - TESTQ R15, R15 - JEQ accumulate_11 // n == 0, we are going to accumulate - VPMOVZXDQ 0*32(R14), Z8 - VPMOVZXDQ 1*32(R14), Z9 - VPMOVZXDQ 2*32(R14), Z10 - VPMOVZXDQ 3*32(R14), Z11 - VPMOVZXDQ 4*32(R14), Z12 - VPMOVZXDQ 5*32(R14), Z13 - VPMOVZXDQ 6*32(R14), Z14 - VPMOVZXDQ 7*32(R14), Z15 - PREFETCHT0 256(R14) - VPADDQ Z8, Z0, Z0 - VPADDQ Z9, Z1, Z1 - VPADDQ Z10, Z2, Z2 - VPADDQ Z11, Z3, Z3 - VPADDQ Z12, Z4, Z4 - VPADDQ Z13, Z5, Z5 - VPADDQ Z14, Z6, Z6 - VPADDQ Z15, Z7, Z7 - - // increment pointers to visit next 8 elements - ADDQ $256, R14 - DECQ R15 // decrement n - JMP loop8by8_8 - -accumulate_11: - // accumulate the 8 Z registers into Z0 - VPADDQ Z7, Z6, Z6 - VPADDQ Z6, Z5, Z5 - VPADDQ Z5, Z4, Z4 - VPADDQ Z4, Z3, Z3 - VPADDQ Z3, Z2, Z2 - VPADDQ Z2, Z1, Z1 - VPADDQ Z1, Z0, Z0 - - // carry propagation - // lo(w0) -> BX - // hi(w0) -> SI - // lo(w1) -> DI - // hi(w1) -> R8 - // lo(w2) -> R9 - // hi(w2) -> R10 - // lo(w3) -> R11 - // hi(w3) -> R12 - VMOVQ X0, BX - VALIGNQ $1, Z0, Z0, Z0 - VMOVQ X0, SI - VALIGNQ $1, Z0, Z0, Z0 - VMOVQ X0, DI - VALIGNQ $1, Z0, Z0, Z0 - VMOVQ X0, R8 - VALIGNQ $1, Z0, Z0, Z0 - VMOVQ X0, R9 - VALIGNQ $1, Z0, Z0, Z0 - VMOVQ X0, R10 - VALIGNQ $1, Z0, Z0, Z0 - VMOVQ X0, R11 - VALIGNQ $1, Z0, Z0, Z0 - VMOVQ X0, R12 - - // lo(hi(wo)) -> R13 - // lo(hi(w1)) -> CX - // lo(hi(w2)) -> R15 - // lo(hi(w3)) -> R14 -#define SPLIT_LO_HI(lo, hi) \ - MOVQ hi, lo; \ - ANDQ $0xffffffff, lo; \ - SHLQ $32, lo; \ - SHRQ $32, hi; \ - - SPLIT_LO_HI(R13, SI) - SPLIT_LO_HI(CX, R8) - SPLIT_LO_HI(R15, R10) - SPLIT_LO_HI(R14, R12) - - // r0 = w0l + lo(woh) - // r1 = carry + hi(woh) + w1l + lo(w1h) - // r2 = carry + hi(w1h) + w2l + lo(w2h) - // r3 = carry + hi(w2h) + w3l + lo(w3h) - // r4 = carry + hi(w3h) - - XORQ AX, AX // clear the flags - ADOXQ R13, BX - ADOXQ CX, DI - ADCXQ SI, DI - ADOXQ R15, R9 - ADCXQ R8, R9 - ADOXQ R14, R11 - ADCXQ R10, R11 - ADOXQ AX, R12 - ADCXQ AX, R12 - - // r[0] -> BX - // r[1] -> DI - // r[2] -> R9 - // r[3] -> R11 - // r[4] -> R12 - // reduce using single-word Barrett - // mu=2^288 / q -> SI - MOVQ mu<>(SB), SI - MOVQ R11, AX - SHRQ $32, R12, AX - MULQ SI // high bits of res stored in DX - MULXQ q<>+0(SB), AX, SI - SUBQ AX, BX - SBBQ SI, DI - MULXQ q<>+16(SB), AX, SI - SBBQ AX, R9 - SBBQ SI, R11 - SBBQ $0, R12 - MULXQ q<>+8(SB), AX, SI - SUBQ AX, DI - SBBQ SI, R9 - MULXQ q<>+24(SB), AX, SI - SBBQ AX, R11 - SBBQ SI, R12 - MOVQ BX, R8 - MOVQ DI, R10 - MOVQ R9, R13 - MOVQ R11, CX - SUBQ q<>+0(SB), BX - SBBQ q<>+8(SB), DI - SBBQ q<>+16(SB), R9 - SBBQ q<>+24(SB), R11 - SBBQ $0, R12 - JCS modReduced_12 - MOVQ BX, R8 - MOVQ DI, R10 - MOVQ R9, R13 - MOVQ R11, CX - SUBQ q<>+0(SB), BX - SBBQ q<>+8(SB), DI - SBBQ q<>+16(SB), R9 - SBBQ q<>+24(SB), R11 - SBBQ $0, R12 - JCS modReduced_12 - MOVQ BX, R8 - MOVQ DI, R10 - MOVQ R9, R13 - MOVQ R11, CX - -modReduced_12: - MOVQ res+0(FP), SI - MOVQ R8, 0(SI) - MOVQ R10, 8(SI) - MOVQ R13, 16(SI) - MOVQ CX, 
24(SI) - -done_9: - RET diff --git a/ecc/bls12-381/fr/element_ops_purego.go b/ecc/bls12-381/fr/element_ops_purego.go index 3c494940b..7bbf9c41a 100644 --- a/ecc/bls12-381/fr/element_ops_purego.go +++ b/ecc/bls12-381/fr/element_ops_purego.go @@ -89,48 +89,8 @@ func (vector *Vector) Sum() (res Element) { // x and y must be less than q func (z *Element) Mul(x, y *Element) *Element { - // Implements CIOS multiplication -- section 2.3.2 of Tolga Acar's thesis - // https://www.microsoft.com/en-us/research/wp-content/uploads/1998/06/97Acar.pdf - // - // The algorithm: - // - // for i=0 to N-1 - // C := 0 - // for j=0 to N-1 - // (C,t[j]) := t[j] + x[j]*y[i] + C - // (t[N+1],t[N]) := t[N] + C - // - // C := 0 - // m := t[0]*q'[0] mod D - // (C,_) := t[0] + m*q[0] - // for j=1 to N-1 - // (C,t[j-1]) := t[j] + m*q[j] + C - // - // (C,t[N-1]) := t[N] + C - // t[N] := t[N+1] + C - // - // → N is the number of machine words needed to store the modulus q - // → D is the word size. For example, on a 64-bit architecture D is 2 64 - // → x[i], y[i], q[i] is the ith word of the numbers x,y,q - // → q'[0] is the lowest word of the number -q⁻¹ mod r. This quantity is pre-computed, as it does not depend on the inputs. - // → t is a temporary array of size N+2 - // → C, S are machine words. A pair (C,S) refers to (hi-bits, lo-bits) of a two-word number - // - // As described here https://hackmd.io/@gnark/modular_multiplication we can get rid of one carry chain and simplify: - // (also described in https://eprint.iacr.org/2022/1400.pdf annex) - // - // for i=0 to N-1 - // (A,t[0]) := t[0] + x[0]*y[i] - // m := t[0]*q'[0] mod W - // C,_ := t[0] + m*q[0] - // for j=1 to N-1 - // (A,t[j]) := t[j] + x[j]*y[i] + A - // (C,t[j-1]) := t[j] + m*q[j] + C - // - // t[N-1] = C + A - // - // This optimization saves 5N + 2 additions in the algorithm, and can be used whenever the highest bit - // of the modulus is zero (and not all of the remaining bits are set). + // Algorithm 2 of "Faster Montgomery Multiplication and Multi-Scalar-Multiplication for SNARKS" + // by Y. El Housni and G. Botrel https://doi.org/10.46586/tches.v2023.i3.504-521 var t0, t1, t2, t3 uint64 var u0, u1, u2, u3 uint64 diff --git a/ecc/bls12-381/fr/element_test.go b/ecc/bls12-381/fr/element_test.go index 68d6739f5..931c6132c 100644 --- a/ecc/bls12-381/fr/element_test.go +++ b/ecc/bls12-381/fr/element_test.go @@ -637,7 +637,6 @@ func TestElementBitLen(t *testing.T) { )) properties.TestingRun(t, gopter.ConsoleReporter(false)) - } func TestElementButterflies(t *testing.T) { @@ -719,11 +718,17 @@ func TestElementVecOps(t *testing.T) { // set m to max values element // it's not really q-1 (since we have montgomery representation) // but it's the "largest" legal value - qMinus1 := new(big.Int).Sub(Modulus(), big.NewInt(1)) - - var eQMinus1 Element - for i, v := range qMinus1.Bits() { - eQMinus1[i] = uint64(v) + eQMinus1 := qElement + if eQMinus1[0] != 0 { + eQMinus1[0]-- + } else { + eQMinus1[0] = ^uint64(0) + for i := 1; i < len(eQMinus1); i++ { + if eQMinus1[i] != 0 { + eQMinus1[i]-- + break + } + } } for i := 0; i < N; i++ { diff --git a/ecc/bls24-315/fp/element.go b/ecc/bls24-315/fp/element.go index 4d6138686..4ab67695e 100644 --- a/ecc/bls24-315/fp/element.go +++ b/ecc/bls24-315/fp/element.go @@ -499,32 +499,8 @@ func (z *Element) Select(c int, x0 *Element, x1 *Element) *Element { // and is used for testing purposes. 
func _mulGeneric(z, x, y *Element) { - // Implements CIOS multiplication -- section 2.3.2 of Tolga Acar's thesis - // https://www.microsoft.com/en-us/research/wp-content/uploads/1998/06/97Acar.pdf - // - // The algorithm: - // - // for i=0 to N-1 - // C := 0 - // for j=0 to N-1 - // (C,t[j]) := t[j] + x[j]*y[i] + C - // (t[N+1],t[N]) := t[N] + C - // - // C := 0 - // m := t[0]*q'[0] mod D - // (C,_) := t[0] + m*q[0] - // for j=1 to N-1 - // (C,t[j-1]) := t[j] + m*q[j] + C - // - // (C,t[N-1]) := t[N] + C - // t[N] := t[N+1] + C - // - // → N is the number of machine words needed to store the modulus q - // → D is the word size. For example, on a 64-bit architecture D is 2 64 - // → x[i], y[i], q[i] is the ith word of the numbers x,y,q - // → q'[0] is the lowest word of the number -q⁻¹ mod r. This quantity is pre-computed, as it does not depend on the inputs. - // → t is a temporary array of size N+2 - // → C, S are machine words. A pair (C,S) refers to (hi-bits, lo-bits) of a two-word number + // Algorithm 2 of "Faster Montgomery Multiplication and Multi-Scalar-Multiplication for SNARKS" + // by Y. El Housni and G. Botrel https://doi.org/10.46586/tches.v2023.i3.504-521 var t [6]uint64 var D uint64 diff --git a/ecc/bls24-315/fp/element_mul_amd64.s b/ecc/bls24-315/fp/element_mul_amd64.s deleted file mode 100644 index 92bba4f58..000000000 --- a/ecc/bls24-315/fp/element_mul_amd64.s +++ /dev/null @@ -1,656 +0,0 @@ -// +build !purego - -// Copyright 2020 ConsenSys Software Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
- -#include "textflag.h" -#include "funcdata.h" - -// modulus q -DATA q<>+0(SB)/8, $0x6fe802ff40300001 -DATA q<>+8(SB)/8, $0x421ee5da52bde502 -DATA q<>+16(SB)/8, $0xdec1d01aa27a1ae0 -DATA q<>+24(SB)/8, $0xd3f7498be97c5eaf -DATA q<>+32(SB)/8, $0x04c23a02b586d650 -GLOBL q<>(SB), (RODATA+NOPTR), $40 - -// qInv0 q'[0] -DATA qInv0<>(SB)/8, $0x702ff9ff402fffff -GLOBL qInv0<>(SB), (RODATA+NOPTR), $8 - -#define REDUCE(ra0, ra1, ra2, ra3, ra4, rb0, rb1, rb2, rb3, rb4) \ - MOVQ ra0, rb0; \ - SUBQ q<>(SB), ra0; \ - MOVQ ra1, rb1; \ - SBBQ q<>+8(SB), ra1; \ - MOVQ ra2, rb2; \ - SBBQ q<>+16(SB), ra2; \ - MOVQ ra3, rb3; \ - SBBQ q<>+24(SB), ra3; \ - MOVQ ra4, rb4; \ - SBBQ q<>+32(SB), ra4; \ - CMOVQCS rb0, ra0; \ - CMOVQCS rb1, ra1; \ - CMOVQCS rb2, ra2; \ - CMOVQCS rb3, ra3; \ - CMOVQCS rb4, ra4; \ - -// mul(res, x, y *Element) -TEXT ·mul(SB), $24-24 - - // the algorithm is described in the Element.Mul declaration (.go) - // however, to benefit from the ADCX and ADOX carry chains - // we split the inner loops in 2: - // for i=0 to N-1 - // for j=0 to N-1 - // (A,t[j]) := t[j] + x[j]*y[i] + A - // m := t[0]*q'[0] mod W - // C,_ := t[0] + m*q[0] - // for j=1 to N-1 - // (C,t[j-1]) := t[j] + m*q[j] + C - // t[N-1] = C + A - - NO_LOCAL_POINTERS - CMPB ·supportAdx(SB), $1 - JNE noAdx_1 - MOVQ x+8(FP), DI - - // x[0] -> R9 - // x[1] -> R10 - // x[2] -> R11 - MOVQ 0(DI), R9 - MOVQ 8(DI), R10 - MOVQ 16(DI), R11 - MOVQ y+16(FP), R12 - - // A -> BP - // t[0] -> R14 - // t[1] -> R13 - // t[2] -> CX - // t[3] -> BX - // t[4] -> SI - // clear the flags - XORQ AX, AX - MOVQ 0(R12), DX - - // (A,t[0]) := x[0]*y[0] + A - MULXQ R9, R14, R13 - - // (A,t[1]) := x[1]*y[0] + A - MULXQ R10, AX, CX - ADOXQ AX, R13 - - // (A,t[2]) := x[2]*y[0] + A - MULXQ R11, AX, BX - ADOXQ AX, CX - - // (A,t[3]) := x[3]*y[0] + A - MULXQ 24(DI), AX, SI - ADOXQ AX, BX - - // (A,t[4]) := x[4]*y[0] + A - MULXQ 32(DI), AX, BP - ADOXQ AX, SI - - // A += carries from ADCXQ and ADOXQ - MOVQ $0, AX - ADOXQ AX, BP - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ R14, DX - - // clear the flags - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ q<>+0(SB), AX, R8 - ADCXQ R14, AX - MOVQ R8, R14 - - // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R13, R14 - MULXQ q<>+8(SB), AX, R13 - ADOXQ AX, R14 - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ CX, R13 - MULXQ q<>+16(SB), AX, CX - ADOXQ AX, R13 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ BX, CX - MULXQ q<>+24(SB), AX, BX - ADOXQ AX, CX - - // (C,t[3]) := t[4] + m*q[4] + C - ADCXQ SI, BX - MULXQ q<>+32(SB), AX, SI - ADOXQ AX, BX - - // t[4] = C + A - MOVQ $0, AX - ADCXQ AX, SI - ADOXQ BP, SI - - // clear the flags - XORQ AX, AX - MOVQ 8(R12), DX - - // (A,t[0]) := t[0] + x[0]*y[1] + A - MULXQ R9, AX, BP - ADOXQ AX, R14 - - // (A,t[1]) := t[1] + x[1]*y[1] + A - ADCXQ BP, R13 - MULXQ R10, AX, BP - ADOXQ AX, R13 - - // (A,t[2]) := t[2] + x[2]*y[1] + A - ADCXQ BP, CX - MULXQ R11, AX, BP - ADOXQ AX, CX - - // (A,t[3]) := t[3] + x[3]*y[1] + A - ADCXQ BP, BX - MULXQ 24(DI), AX, BP - ADOXQ AX, BX - - // (A,t[4]) := t[4] + x[4]*y[1] + A - ADCXQ BP, SI - MULXQ 32(DI), AX, BP - ADOXQ AX, SI - - // A += carries from ADCXQ and ADOXQ - MOVQ $0, AX - ADCXQ AX, BP - ADOXQ AX, BP - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ R14, DX - - // clear the flags - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ q<>+0(SB), AX, R8 - ADCXQ R14, AX - MOVQ R8, R14 - - // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R13, R14 - MULXQ q<>+8(SB), AX, R13 - ADOXQ AX, R14 - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ CX, R13 - MULXQ 
q<>+16(SB), AX, CX - ADOXQ AX, R13 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ BX, CX - MULXQ q<>+24(SB), AX, BX - ADOXQ AX, CX - - // (C,t[3]) := t[4] + m*q[4] + C - ADCXQ SI, BX - MULXQ q<>+32(SB), AX, SI - ADOXQ AX, BX - - // t[4] = C + A - MOVQ $0, AX - ADCXQ AX, SI - ADOXQ BP, SI - - // clear the flags - XORQ AX, AX - MOVQ 16(R12), DX - - // (A,t[0]) := t[0] + x[0]*y[2] + A - MULXQ R9, AX, BP - ADOXQ AX, R14 - - // (A,t[1]) := t[1] + x[1]*y[2] + A - ADCXQ BP, R13 - MULXQ R10, AX, BP - ADOXQ AX, R13 - - // (A,t[2]) := t[2] + x[2]*y[2] + A - ADCXQ BP, CX - MULXQ R11, AX, BP - ADOXQ AX, CX - - // (A,t[3]) := t[3] + x[3]*y[2] + A - ADCXQ BP, BX - MULXQ 24(DI), AX, BP - ADOXQ AX, BX - - // (A,t[4]) := t[4] + x[4]*y[2] + A - ADCXQ BP, SI - MULXQ 32(DI), AX, BP - ADOXQ AX, SI - - // A += carries from ADCXQ and ADOXQ - MOVQ $0, AX - ADCXQ AX, BP - ADOXQ AX, BP - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ R14, DX - - // clear the flags - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ q<>+0(SB), AX, R8 - ADCXQ R14, AX - MOVQ R8, R14 - - // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R13, R14 - MULXQ q<>+8(SB), AX, R13 - ADOXQ AX, R14 - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ CX, R13 - MULXQ q<>+16(SB), AX, CX - ADOXQ AX, R13 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ BX, CX - MULXQ q<>+24(SB), AX, BX - ADOXQ AX, CX - - // (C,t[3]) := t[4] + m*q[4] + C - ADCXQ SI, BX - MULXQ q<>+32(SB), AX, SI - ADOXQ AX, BX - - // t[4] = C + A - MOVQ $0, AX - ADCXQ AX, SI - ADOXQ BP, SI - - // clear the flags - XORQ AX, AX - MOVQ 24(R12), DX - - // (A,t[0]) := t[0] + x[0]*y[3] + A - MULXQ R9, AX, BP - ADOXQ AX, R14 - - // (A,t[1]) := t[1] + x[1]*y[3] + A - ADCXQ BP, R13 - MULXQ R10, AX, BP - ADOXQ AX, R13 - - // (A,t[2]) := t[2] + x[2]*y[3] + A - ADCXQ BP, CX - MULXQ R11, AX, BP - ADOXQ AX, CX - - // (A,t[3]) := t[3] + x[3]*y[3] + A - ADCXQ BP, BX - MULXQ 24(DI), AX, BP - ADOXQ AX, BX - - // (A,t[4]) := t[4] + x[4]*y[3] + A - ADCXQ BP, SI - MULXQ 32(DI), AX, BP - ADOXQ AX, SI - - // A += carries from ADCXQ and ADOXQ - MOVQ $0, AX - ADCXQ AX, BP - ADOXQ AX, BP - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ R14, DX - - // clear the flags - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ q<>+0(SB), AX, R8 - ADCXQ R14, AX - MOVQ R8, R14 - - // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R13, R14 - MULXQ q<>+8(SB), AX, R13 - ADOXQ AX, R14 - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ CX, R13 - MULXQ q<>+16(SB), AX, CX - ADOXQ AX, R13 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ BX, CX - MULXQ q<>+24(SB), AX, BX - ADOXQ AX, CX - - // (C,t[3]) := t[4] + m*q[4] + C - ADCXQ SI, BX - MULXQ q<>+32(SB), AX, SI - ADOXQ AX, BX - - // t[4] = C + A - MOVQ $0, AX - ADCXQ AX, SI - ADOXQ BP, SI - - // clear the flags - XORQ AX, AX - MOVQ 32(R12), DX - - // (A,t[0]) := t[0] + x[0]*y[4] + A - MULXQ R9, AX, BP - ADOXQ AX, R14 - - // (A,t[1]) := t[1] + x[1]*y[4] + A - ADCXQ BP, R13 - MULXQ R10, AX, BP - ADOXQ AX, R13 - - // (A,t[2]) := t[2] + x[2]*y[4] + A - ADCXQ BP, CX - MULXQ R11, AX, BP - ADOXQ AX, CX - - // (A,t[3]) := t[3] + x[3]*y[4] + A - ADCXQ BP, BX - MULXQ 24(DI), AX, BP - ADOXQ AX, BX - - // (A,t[4]) := t[4] + x[4]*y[4] + A - ADCXQ BP, SI - MULXQ 32(DI), AX, BP - ADOXQ AX, SI - - // A += carries from ADCXQ and ADOXQ - MOVQ $0, AX - ADCXQ AX, BP - ADOXQ AX, BP - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ R14, DX - - // clear the flags - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ q<>+0(SB), AX, R8 - ADCXQ R14, AX - MOVQ R8, R14 - - // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R13, R14 - 
MULXQ q<>+8(SB), AX, R13 - ADOXQ AX, R14 - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ CX, R13 - MULXQ q<>+16(SB), AX, CX - ADOXQ AX, R13 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ BX, CX - MULXQ q<>+24(SB), AX, BX - ADOXQ AX, CX - - // (C,t[3]) := t[4] + m*q[4] + C - ADCXQ SI, BX - MULXQ q<>+32(SB), AX, SI - ADOXQ AX, BX - - // t[4] = C + A - MOVQ $0, AX - ADCXQ AX, SI - ADOXQ BP, SI - - // reduce element(R14,R13,CX,BX,SI) using temp registers (R8,DI,R12,R9,R10) - REDUCE(R14,R13,CX,BX,SI,R8,DI,R12,R9,R10) - - MOVQ res+0(FP), AX - MOVQ R14, 0(AX) - MOVQ R13, 8(AX) - MOVQ CX, 16(AX) - MOVQ BX, 24(AX) - MOVQ SI, 32(AX) - RET - -noAdx_1: - MOVQ res+0(FP), AX - MOVQ AX, (SP) - MOVQ x+8(FP), AX - MOVQ AX, 8(SP) - MOVQ y+16(FP), AX - MOVQ AX, 16(SP) - CALL ·_mulGeneric(SB) - RET - -TEXT ·fromMont(SB), $8-8 - NO_LOCAL_POINTERS - - // the algorithm is described here - // https://hackmd.io/@gnark/modular_multiplication - // when y = 1 we have: - // for i=0 to N-1 - // t[i] = x[i] - // for i=0 to N-1 - // m := t[0]*q'[0] mod W - // C,_ := t[0] + m*q[0] - // for j=1 to N-1 - // (C,t[j-1]) := t[j] + m*q[j] + C - // t[N-1] = C - CMPB ·supportAdx(SB), $1 - JNE noAdx_2 - MOVQ res+0(FP), DX - MOVQ 0(DX), R14 - MOVQ 8(DX), R13 - MOVQ 16(DX), CX - MOVQ 24(DX), BX - MOVQ 32(DX), SI - XORQ DX, DX - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ R14, DX - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ q<>+0(SB), AX, BP - ADCXQ R14, AX - MOVQ BP, R14 - - // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R13, R14 - MULXQ q<>+8(SB), AX, R13 - ADOXQ AX, R14 - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ CX, R13 - MULXQ q<>+16(SB), AX, CX - ADOXQ AX, R13 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ BX, CX - MULXQ q<>+24(SB), AX, BX - ADOXQ AX, CX - - // (C,t[3]) := t[4] + m*q[4] + C - ADCXQ SI, BX - MULXQ q<>+32(SB), AX, SI - ADOXQ AX, BX - MOVQ $0, AX - ADCXQ AX, SI - ADOXQ AX, SI - XORQ DX, DX - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ R14, DX - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ q<>+0(SB), AX, BP - ADCXQ R14, AX - MOVQ BP, R14 - - // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R13, R14 - MULXQ q<>+8(SB), AX, R13 - ADOXQ AX, R14 - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ CX, R13 - MULXQ q<>+16(SB), AX, CX - ADOXQ AX, R13 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ BX, CX - MULXQ q<>+24(SB), AX, BX - ADOXQ AX, CX - - // (C,t[3]) := t[4] + m*q[4] + C - ADCXQ SI, BX - MULXQ q<>+32(SB), AX, SI - ADOXQ AX, BX - MOVQ $0, AX - ADCXQ AX, SI - ADOXQ AX, SI - XORQ DX, DX - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ R14, DX - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ q<>+0(SB), AX, BP - ADCXQ R14, AX - MOVQ BP, R14 - - // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R13, R14 - MULXQ q<>+8(SB), AX, R13 - ADOXQ AX, R14 - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ CX, R13 - MULXQ q<>+16(SB), AX, CX - ADOXQ AX, R13 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ BX, CX - MULXQ q<>+24(SB), AX, BX - ADOXQ AX, CX - - // (C,t[3]) := t[4] + m*q[4] + C - ADCXQ SI, BX - MULXQ q<>+32(SB), AX, SI - ADOXQ AX, BX - MOVQ $0, AX - ADCXQ AX, SI - ADOXQ AX, SI - XORQ DX, DX - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ R14, DX - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ q<>+0(SB), AX, BP - ADCXQ R14, AX - MOVQ BP, R14 - - // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R13, R14 - MULXQ q<>+8(SB), AX, R13 - ADOXQ AX, R14 - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ CX, R13 - MULXQ q<>+16(SB), AX, CX - ADOXQ AX, R13 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ BX, CX - MULXQ q<>+24(SB), AX, 
BX - ADOXQ AX, CX - - // (C,t[3]) := t[4] + m*q[4] + C - ADCXQ SI, BX - MULXQ q<>+32(SB), AX, SI - ADOXQ AX, BX - MOVQ $0, AX - ADCXQ AX, SI - ADOXQ AX, SI - XORQ DX, DX - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ R14, DX - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ q<>+0(SB), AX, BP - ADCXQ R14, AX - MOVQ BP, R14 - - // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R13, R14 - MULXQ q<>+8(SB), AX, R13 - ADOXQ AX, R14 - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ CX, R13 - MULXQ q<>+16(SB), AX, CX - ADOXQ AX, R13 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ BX, CX - MULXQ q<>+24(SB), AX, BX - ADOXQ AX, CX - - // (C,t[3]) := t[4] + m*q[4] + C - ADCXQ SI, BX - MULXQ q<>+32(SB), AX, SI - ADOXQ AX, BX - MOVQ $0, AX - ADCXQ AX, SI - ADOXQ AX, SI - - // reduce element(R14,R13,CX,BX,SI) using temp registers (DI,R8,R9,R10,R11) - REDUCE(R14,R13,CX,BX,SI,DI,R8,R9,R10,R11) - - MOVQ res+0(FP), AX - MOVQ R14, 0(AX) - MOVQ R13, 8(AX) - MOVQ CX, 16(AX) - MOVQ BX, 24(AX) - MOVQ SI, 32(AX) - RET - -noAdx_2: - MOVQ res+0(FP), AX - MOVQ AX, (SP) - CALL ·_fromMontGeneric(SB) - RET diff --git a/ecc/bls24-315/fp/element_ops_amd64.go b/ecc/bls24-315/fp/element_ops_amd64.go index 83bba45ae..ed2803d71 100644 --- a/ecc/bls24-315/fp/element_ops_amd64.go +++ b/ecc/bls24-315/fp/element_ops_amd64.go @@ -50,48 +50,8 @@ func Butterfly(a, b *Element) // x and y must be less than q func (z *Element) Mul(x, y *Element) *Element { - // Implements CIOS multiplication -- section 2.3.2 of Tolga Acar's thesis - // https://www.microsoft.com/en-us/research/wp-content/uploads/1998/06/97Acar.pdf - // - // The algorithm: - // - // for i=0 to N-1 - // C := 0 - // for j=0 to N-1 - // (C,t[j]) := t[j] + x[j]*y[i] + C - // (t[N+1],t[N]) := t[N] + C - // - // C := 0 - // m := t[0]*q'[0] mod D - // (C,_) := t[0] + m*q[0] - // for j=1 to N-1 - // (C,t[j-1]) := t[j] + m*q[j] + C - // - // (C,t[N-1]) := t[N] + C - // t[N] := t[N+1] + C - // - // → N is the number of machine words needed to store the modulus q - // → D is the word size. For example, on a 64-bit architecture D is 2 64 - // → x[i], y[i], q[i] is the ith word of the numbers x,y,q - // → q'[0] is the lowest word of the number -q⁻¹ mod r. This quantity is pre-computed, as it does not depend on the inputs. - // → t is a temporary array of size N+2 - // → C, S are machine words. A pair (C,S) refers to (hi-bits, lo-bits) of a two-word number - // - // As described here https://hackmd.io/@gnark/modular_multiplication we can get rid of one carry chain and simplify: - // (also described in https://eprint.iacr.org/2022/1400.pdf annex) - // - // for i=0 to N-1 - // (A,t[0]) := t[0] + x[0]*y[i] - // m := t[0]*q'[0] mod W - // C,_ := t[0] + m*q[0] - // for j=1 to N-1 - // (A,t[j]) := t[j] + x[j]*y[i] + A - // (C,t[j-1]) := t[j] + m*q[j] + C - // - // t[N-1] = C + A - // - // This optimization saves 5N + 2 additions in the algorithm, and can be used whenever the highest bit - // of the modulus is zero (and not all of the remaining bits are set). + // Algorithm 2 of "Faster Montgomery Multiplication and Multi-Scalar-Multiplication for SNARKS" + // by Y. El Housni and G. Botrel https://doi.org/10.46586/tches.v2023.i3.504-521 mul(z, x, y) return z diff --git a/ecc/bls24-315/fp/element_ops_amd64.s b/ecc/bls24-315/fp/element_ops_amd64.s index 9528ab595..87385b8d8 100644 --- a/ecc/bls24-315/fp/element_ops_amd64.s +++ b/ecc/bls24-315/fp/element_ops_amd64.s @@ -1,21 +1,12 @@ // +build !purego -// Copyright 2020 ConsenSys Software Inc. 
-// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. +#define q0 $0x6fe802ff40300001 +#define q1 $0x421ee5da52bde502 +#define q2 $0xdec1d01aa27a1ae0 +#define q3 $0xd3f7498be97c5eaf +#define q4 $0x04c23a02b586d650 -#include "textflag.h" -#include "funcdata.h" +#include "../../../field/asm/element_5w_amd64.h" // modulus q DATA q<>+0(SB)/8, $0x6fe802ff40300001 @@ -29,244 +20,3 @@ GLOBL q<>(SB), (RODATA+NOPTR), $40 DATA qInv0<>(SB)/8, $0x702ff9ff402fffff GLOBL qInv0<>(SB), (RODATA+NOPTR), $8 -#define REDUCE(ra0, ra1, ra2, ra3, ra4, rb0, rb1, rb2, rb3, rb4) \ - MOVQ ra0, rb0; \ - SUBQ q<>(SB), ra0; \ - MOVQ ra1, rb1; \ - SBBQ q<>+8(SB), ra1; \ - MOVQ ra2, rb2; \ - SBBQ q<>+16(SB), ra2; \ - MOVQ ra3, rb3; \ - SBBQ q<>+24(SB), ra3; \ - MOVQ ra4, rb4; \ - SBBQ q<>+32(SB), ra4; \ - CMOVQCS rb0, ra0; \ - CMOVQCS rb1, ra1; \ - CMOVQCS rb2, ra2; \ - CMOVQCS rb3, ra3; \ - CMOVQCS rb4, ra4; \ - -TEXT ·reduce(SB), NOSPLIT, $0-8 - MOVQ res+0(FP), AX - MOVQ 0(AX), DX - MOVQ 8(AX), CX - MOVQ 16(AX), BX - MOVQ 24(AX), SI - MOVQ 32(AX), DI - - // reduce element(DX,CX,BX,SI,DI) using temp registers (R8,R9,R10,R11,R12) - REDUCE(DX,CX,BX,SI,DI,R8,R9,R10,R11,R12) - - MOVQ DX, 0(AX) - MOVQ CX, 8(AX) - MOVQ BX, 16(AX) - MOVQ SI, 24(AX) - MOVQ DI, 32(AX) - RET - -// MulBy3(x *Element) -TEXT ·MulBy3(SB), NOSPLIT, $0-8 - MOVQ x+0(FP), AX - MOVQ 0(AX), DX - MOVQ 8(AX), CX - MOVQ 16(AX), BX - MOVQ 24(AX), SI - MOVQ 32(AX), DI - ADDQ DX, DX - ADCQ CX, CX - ADCQ BX, BX - ADCQ SI, SI - ADCQ DI, DI - - // reduce element(DX,CX,BX,SI,DI) using temp registers (R8,R9,R10,R11,R12) - REDUCE(DX,CX,BX,SI,DI,R8,R9,R10,R11,R12) - - ADDQ 0(AX), DX - ADCQ 8(AX), CX - ADCQ 16(AX), BX - ADCQ 24(AX), SI - ADCQ 32(AX), DI - - // reduce element(DX,CX,BX,SI,DI) using temp registers (R13,R14,R15,R8,R9) - REDUCE(DX,CX,BX,SI,DI,R13,R14,R15,R8,R9) - - MOVQ DX, 0(AX) - MOVQ CX, 8(AX) - MOVQ BX, 16(AX) - MOVQ SI, 24(AX) - MOVQ DI, 32(AX) - RET - -// MulBy5(x *Element) -TEXT ·MulBy5(SB), NOSPLIT, $0-8 - MOVQ x+0(FP), AX - MOVQ 0(AX), DX - MOVQ 8(AX), CX - MOVQ 16(AX), BX - MOVQ 24(AX), SI - MOVQ 32(AX), DI - ADDQ DX, DX - ADCQ CX, CX - ADCQ BX, BX - ADCQ SI, SI - ADCQ DI, DI - - // reduce element(DX,CX,BX,SI,DI) using temp registers (R8,R9,R10,R11,R12) - REDUCE(DX,CX,BX,SI,DI,R8,R9,R10,R11,R12) - - ADDQ DX, DX - ADCQ CX, CX - ADCQ BX, BX - ADCQ SI, SI - ADCQ DI, DI - - // reduce element(DX,CX,BX,SI,DI) using temp registers (R13,R14,R15,R8,R9) - REDUCE(DX,CX,BX,SI,DI,R13,R14,R15,R8,R9) - - ADDQ 0(AX), DX - ADCQ 8(AX), CX - ADCQ 16(AX), BX - ADCQ 24(AX), SI - ADCQ 32(AX), DI - - // reduce element(DX,CX,BX,SI,DI) using temp registers (R10,R11,R12,R13,R14) - REDUCE(DX,CX,BX,SI,DI,R10,R11,R12,R13,R14) - - MOVQ DX, 0(AX) - MOVQ CX, 8(AX) - MOVQ BX, 16(AX) - MOVQ SI, 24(AX) - MOVQ DI, 32(AX) - RET - -// MulBy13(x *Element) -TEXT ·MulBy13(SB), $16-8 - MOVQ x+0(FP), AX - MOVQ 0(AX), DX - MOVQ 8(AX), CX - MOVQ 16(AX), BX - MOVQ 24(AX), SI - MOVQ 32(AX), DI - ADDQ DX, DX - ADCQ CX, CX - ADCQ BX, BX - ADCQ SI, SI - ADCQ DI, DI - - // reduce 
element(DX,CX,BX,SI,DI) using temp registers (R8,R9,R10,R11,R12) - REDUCE(DX,CX,BX,SI,DI,R8,R9,R10,R11,R12) - - ADDQ DX, DX - ADCQ CX, CX - ADCQ BX, BX - ADCQ SI, SI - ADCQ DI, DI - - // reduce element(DX,CX,BX,SI,DI) using temp registers (R13,R14,R15,s0-8(SP),s1-16(SP)) - REDUCE(DX,CX,BX,SI,DI,R13,R14,R15,s0-8(SP),s1-16(SP)) - - MOVQ DX, R13 - MOVQ CX, R14 - MOVQ BX, R15 - MOVQ SI, s0-8(SP) - MOVQ DI, s1-16(SP) - ADDQ DX, DX - ADCQ CX, CX - ADCQ BX, BX - ADCQ SI, SI - ADCQ DI, DI - - // reduce element(DX,CX,BX,SI,DI) using temp registers (R8,R9,R10,R11,R12) - REDUCE(DX,CX,BX,SI,DI,R8,R9,R10,R11,R12) - - ADDQ R13, DX - ADCQ R14, CX - ADCQ R15, BX - ADCQ s0-8(SP), SI - ADCQ s1-16(SP), DI - - // reduce element(DX,CX,BX,SI,DI) using temp registers (R8,R9,R10,R11,R12) - REDUCE(DX,CX,BX,SI,DI,R8,R9,R10,R11,R12) - - ADDQ 0(AX), DX - ADCQ 8(AX), CX - ADCQ 16(AX), BX - ADCQ 24(AX), SI - ADCQ 32(AX), DI - - // reduce element(DX,CX,BX,SI,DI) using temp registers (R8,R9,R10,R11,R12) - REDUCE(DX,CX,BX,SI,DI,R8,R9,R10,R11,R12) - - MOVQ DX, 0(AX) - MOVQ CX, 8(AX) - MOVQ BX, 16(AX) - MOVQ SI, 24(AX) - MOVQ DI, 32(AX) - RET - -// Butterfly(a, b *Element) sets a = a + b; b = a - b -TEXT ·Butterfly(SB), $24-16 - MOVQ a+0(FP), AX - MOVQ 0(AX), CX - MOVQ 8(AX), BX - MOVQ 16(AX), SI - MOVQ 24(AX), DI - MOVQ 32(AX), R8 - MOVQ CX, R9 - MOVQ BX, R10 - MOVQ SI, R11 - MOVQ DI, R12 - MOVQ R8, R13 - XORQ AX, AX - MOVQ b+8(FP), DX - ADDQ 0(DX), CX - ADCQ 8(DX), BX - ADCQ 16(DX), SI - ADCQ 24(DX), DI - ADCQ 32(DX), R8 - SUBQ 0(DX), R9 - SBBQ 8(DX), R10 - SBBQ 16(DX), R11 - SBBQ 24(DX), R12 - SBBQ 32(DX), R13 - MOVQ CX, R14 - MOVQ BX, R15 - MOVQ SI, s0-8(SP) - MOVQ DI, s1-16(SP) - MOVQ R8, s2-24(SP) - MOVQ $0x6fe802ff40300001, CX - MOVQ $0x421ee5da52bde502, BX - MOVQ $0xdec1d01aa27a1ae0, SI - MOVQ $0xd3f7498be97c5eaf, DI - MOVQ $0x04c23a02b586d650, R8 - CMOVQCC AX, CX - CMOVQCC AX, BX - CMOVQCC AX, SI - CMOVQCC AX, DI - CMOVQCC AX, R8 - ADDQ CX, R9 - ADCQ BX, R10 - ADCQ SI, R11 - ADCQ DI, R12 - ADCQ R8, R13 - MOVQ R14, CX - MOVQ R15, BX - MOVQ s0-8(SP), SI - MOVQ s1-16(SP), DI - MOVQ s2-24(SP), R8 - MOVQ R9, 0(DX) - MOVQ R10, 8(DX) - MOVQ R11, 16(DX) - MOVQ R12, 24(DX) - MOVQ R13, 32(DX) - - // reduce element(CX,BX,SI,DI,R8) using temp registers (R9,R10,R11,R12,R13) - REDUCE(CX,BX,SI,DI,R8,R9,R10,R11,R12,R13) - - MOVQ a+0(FP), AX - MOVQ CX, 0(AX) - MOVQ BX, 8(AX) - MOVQ SI, 16(AX) - MOVQ DI, 24(AX) - MOVQ R8, 32(AX) - RET diff --git a/ecc/bls24-315/fp/element_ops_purego.go b/ecc/bls24-315/fp/element_ops_purego.go index 9a557a358..4796fc3c5 100644 --- a/ecc/bls24-315/fp/element_ops_purego.go +++ b/ecc/bls24-315/fp/element_ops_purego.go @@ -66,48 +66,8 @@ func reduce(z *Element) { // x and y must be less than q func (z *Element) Mul(x, y *Element) *Element { - // Implements CIOS multiplication -- section 2.3.2 of Tolga Acar's thesis - // https://www.microsoft.com/en-us/research/wp-content/uploads/1998/06/97Acar.pdf - // - // The algorithm: - // - // for i=0 to N-1 - // C := 0 - // for j=0 to N-1 - // (C,t[j]) := t[j] + x[j]*y[i] + C - // (t[N+1],t[N]) := t[N] + C - // - // C := 0 - // m := t[0]*q'[0] mod D - // (C,_) := t[0] + m*q[0] - // for j=1 to N-1 - // (C,t[j-1]) := t[j] + m*q[j] + C - // - // (C,t[N-1]) := t[N] + C - // t[N] := t[N+1] + C - // - // → N is the number of machine words needed to store the modulus q - // → D is the word size. For example, on a 64-bit architecture D is 2 64 - // → x[i], y[i], q[i] is the ith word of the numbers x,y,q - // → q'[0] is the lowest word of the number -q⁻¹ mod r. 
This quantity is pre-computed, as it does not depend on the inputs. - // → t is a temporary array of size N+2 - // → C, S are machine words. A pair (C,S) refers to (hi-bits, lo-bits) of a two-word number - // - // As described here https://hackmd.io/@gnark/modular_multiplication we can get rid of one carry chain and simplify: - // (also described in https://eprint.iacr.org/2022/1400.pdf annex) - // - // for i=0 to N-1 - // (A,t[0]) := t[0] + x[0]*y[i] - // m := t[0]*q'[0] mod W - // C,_ := t[0] + m*q[0] - // for j=1 to N-1 - // (A,t[j]) := t[j] + x[j]*y[i] + A - // (C,t[j-1]) := t[j] + m*q[j] + C - // - // t[N-1] = C + A - // - // This optimization saves 5N + 2 additions in the algorithm, and can be used whenever the highest bit - // of the modulus is zero (and not all of the remaining bits are set). + // Algorithm 2 of "Faster Montgomery Multiplication and Multi-Scalar-Multiplication for SNARKS" + // by Y. El Housni and G. Botrel https://doi.org/10.46586/tches.v2023.i3.504-521 var t0, t1, t2, t3, t4 uint64 var u0, u1, u2, u3, u4 uint64 diff --git a/ecc/bls24-315/fp/element_test.go b/ecc/bls24-315/fp/element_test.go index c48910446..17fa1f06e 100644 --- a/ecc/bls24-315/fp/element_test.go +++ b/ecc/bls24-315/fp/element_test.go @@ -639,7 +639,6 @@ func TestElementBitLen(t *testing.T) { )) properties.TestingRun(t, gopter.ConsoleReporter(false)) - } func TestElementButterflies(t *testing.T) { @@ -721,11 +720,17 @@ func TestElementVecOps(t *testing.T) { // set m to max values element // it's not really q-1 (since we have montgomery representation) // but it's the "largest" legal value - qMinus1 := new(big.Int).Sub(Modulus(), big.NewInt(1)) - - var eQMinus1 Element - for i, v := range qMinus1.Bits() { - eQMinus1[i] = uint64(v) + eQMinus1 := qElement + if eQMinus1[0] != 0 { + eQMinus1[0]-- + } else { + eQMinus1[0] = ^uint64(0) + for i := 1; i < len(eQMinus1); i++ { + if eQMinus1[i] != 0 { + eQMinus1[i]-- + break + } + } } for i := 0; i < N; i++ { diff --git a/ecc/bls24-315/fr/element.go b/ecc/bls24-315/fr/element.go index c24a104a6..d7fb19b10 100644 --- a/ecc/bls24-315/fr/element.go +++ b/ecc/bls24-315/fr/element.go @@ -477,32 +477,8 @@ func (z *Element) Select(c int, x0 *Element, x1 *Element) *Element { // and is used for testing purposes. func _mulGeneric(z, x, y *Element) { - // Implements CIOS multiplication -- section 2.3.2 of Tolga Acar's thesis - // https://www.microsoft.com/en-us/research/wp-content/uploads/1998/06/97Acar.pdf - // - // The algorithm: - // - // for i=0 to N-1 - // C := 0 - // for j=0 to N-1 - // (C,t[j]) := t[j] + x[j]*y[i] + C - // (t[N+1],t[N]) := t[N] + C - // - // C := 0 - // m := t[0]*q'[0] mod D - // (C,_) := t[0] + m*q[0] - // for j=1 to N-1 - // (C,t[j-1]) := t[j] + m*q[j] + C - // - // (C,t[N-1]) := t[N] + C - // t[N] := t[N+1] + C - // - // → N is the number of machine words needed to store the modulus q - // → D is the word size. For example, on a 64-bit architecture D is 2 64 - // → x[i], y[i], q[i] is the ith word of the numbers x,y,q - // → q'[0] is the lowest word of the number -q⁻¹ mod r. This quantity is pre-computed, as it does not depend on the inputs. - // → t is a temporary array of size N+2 - // → C, S are machine words. A pair (C,S) refers to (hi-bits, lo-bits) of a two-word number + // Algorithm 2 of "Faster Montgomery Multiplication and Multi-Scalar-Multiplication for SNARKS" + // by Y. El Housni and G. 
Botrel https://doi.org/10.46586/tches.v2023.i3.504-521 var t [5]uint64 var D uint64 diff --git a/ecc/bls24-315/fr/element_mul_amd64.s b/ecc/bls24-315/fr/element_mul_amd64.s deleted file mode 100644 index 2afbbd3fe..000000000 --- a/ecc/bls24-315/fr/element_mul_amd64.s +++ /dev/null @@ -1,490 +0,0 @@ -// +build !purego - -// Copyright 2020 ConsenSys Software Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "textflag.h" -#include "funcdata.h" - -// modulus q -DATA q<>+0(SB)/8, $0x19d0c5fd00c00001 -DATA q<>+8(SB)/8, $0xc8c480ece644e364 -DATA q<>+16(SB)/8, $0x25fc7ec9cf927a98 -DATA q<>+24(SB)/8, $0x196deac24a9da12b -GLOBL q<>(SB), (RODATA+NOPTR), $32 - -// qInv0 q'[0] -DATA qInv0<>(SB)/8, $0x1e5035fd00bfffff -GLOBL qInv0<>(SB), (RODATA+NOPTR), $8 -// Mu -DATA mu<>(SB)/8, $0x0000000a112d9c09 -GLOBL mu<>(SB), (RODATA+NOPTR), $8 - -#define REDUCE(ra0, ra1, ra2, ra3, rb0, rb1, rb2, rb3) \ - MOVQ ra0, rb0; \ - SUBQ q<>(SB), ra0; \ - MOVQ ra1, rb1; \ - SBBQ q<>+8(SB), ra1; \ - MOVQ ra2, rb2; \ - SBBQ q<>+16(SB), ra2; \ - MOVQ ra3, rb3; \ - SBBQ q<>+24(SB), ra3; \ - CMOVQCS rb0, ra0; \ - CMOVQCS rb1, ra1; \ - CMOVQCS rb2, ra2; \ - CMOVQCS rb3, ra3; \ - -// mul(res, x, y *Element) -TEXT ·mul(SB), $24-24 - - // the algorithm is described in the Element.Mul declaration (.go) - // however, to benefit from the ADCX and ADOX carry chains - // we split the inner loops in 2: - // for i=0 to N-1 - // for j=0 to N-1 - // (A,t[j]) := t[j] + x[j]*y[i] + A - // m := t[0]*q'[0] mod W - // C,_ := t[0] + m*q[0] - // for j=1 to N-1 - // (C,t[j-1]) := t[j] + m*q[j] + C - // t[N-1] = C + A - - NO_LOCAL_POINTERS - CMPB ·supportAdx(SB), $1 - JNE noAdx_1 - MOVQ x+8(FP), SI - - // x[0] -> DI - // x[1] -> R8 - // x[2] -> R9 - // x[3] -> R10 - MOVQ 0(SI), DI - MOVQ 8(SI), R8 - MOVQ 16(SI), R9 - MOVQ 24(SI), R10 - MOVQ y+16(FP), R11 - - // A -> BP - // t[0] -> R14 - // t[1] -> R13 - // t[2] -> CX - // t[3] -> BX - // clear the flags - XORQ AX, AX - MOVQ 0(R11), DX - - // (A,t[0]) := x[0]*y[0] + A - MULXQ DI, R14, R13 - - // (A,t[1]) := x[1]*y[0] + A - MULXQ R8, AX, CX - ADOXQ AX, R13 - - // (A,t[2]) := x[2]*y[0] + A - MULXQ R9, AX, BX - ADOXQ AX, CX - - // (A,t[3]) := x[3]*y[0] + A - MULXQ R10, AX, BP - ADOXQ AX, BX - - // A += carries from ADCXQ and ADOXQ - MOVQ $0, AX - ADOXQ AX, BP - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ R14, DX - - // clear the flags - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ q<>+0(SB), AX, R12 - ADCXQ R14, AX - MOVQ R12, R14 - - // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R13, R14 - MULXQ q<>+8(SB), AX, R13 - ADOXQ AX, R14 - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ CX, R13 - MULXQ q<>+16(SB), AX, CX - ADOXQ AX, R13 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ BX, CX - MULXQ q<>+24(SB), AX, BX - ADOXQ AX, CX - - // t[3] = C + A - MOVQ $0, AX - ADCXQ AX, BX - ADOXQ BP, BX - - // clear the flags - XORQ AX, AX - MOVQ 8(R11), DX - - // (A,t[0]) := t[0] + x[0]*y[1] + A - MULXQ DI, AX, BP - ADOXQ AX, R14 - - // (A,t[1]) := t[1] + x[1]*y[1] + A - ADCXQ BP, R13 - 
MULXQ R8, AX, BP - ADOXQ AX, R13 - - // (A,t[2]) := t[2] + x[2]*y[1] + A - ADCXQ BP, CX - MULXQ R9, AX, BP - ADOXQ AX, CX - - // (A,t[3]) := t[3] + x[3]*y[1] + A - ADCXQ BP, BX - MULXQ R10, AX, BP - ADOXQ AX, BX - - // A += carries from ADCXQ and ADOXQ - MOVQ $0, AX - ADCXQ AX, BP - ADOXQ AX, BP - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ R14, DX - - // clear the flags - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ q<>+0(SB), AX, R12 - ADCXQ R14, AX - MOVQ R12, R14 - - // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R13, R14 - MULXQ q<>+8(SB), AX, R13 - ADOXQ AX, R14 - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ CX, R13 - MULXQ q<>+16(SB), AX, CX - ADOXQ AX, R13 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ BX, CX - MULXQ q<>+24(SB), AX, BX - ADOXQ AX, CX - - // t[3] = C + A - MOVQ $0, AX - ADCXQ AX, BX - ADOXQ BP, BX - - // clear the flags - XORQ AX, AX - MOVQ 16(R11), DX - - // (A,t[0]) := t[0] + x[0]*y[2] + A - MULXQ DI, AX, BP - ADOXQ AX, R14 - - // (A,t[1]) := t[1] + x[1]*y[2] + A - ADCXQ BP, R13 - MULXQ R8, AX, BP - ADOXQ AX, R13 - - // (A,t[2]) := t[2] + x[2]*y[2] + A - ADCXQ BP, CX - MULXQ R9, AX, BP - ADOXQ AX, CX - - // (A,t[3]) := t[3] + x[3]*y[2] + A - ADCXQ BP, BX - MULXQ R10, AX, BP - ADOXQ AX, BX - - // A += carries from ADCXQ and ADOXQ - MOVQ $0, AX - ADCXQ AX, BP - ADOXQ AX, BP - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ R14, DX - - // clear the flags - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ q<>+0(SB), AX, R12 - ADCXQ R14, AX - MOVQ R12, R14 - - // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R13, R14 - MULXQ q<>+8(SB), AX, R13 - ADOXQ AX, R14 - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ CX, R13 - MULXQ q<>+16(SB), AX, CX - ADOXQ AX, R13 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ BX, CX - MULXQ q<>+24(SB), AX, BX - ADOXQ AX, CX - - // t[3] = C + A - MOVQ $0, AX - ADCXQ AX, BX - ADOXQ BP, BX - - // clear the flags - XORQ AX, AX - MOVQ 24(R11), DX - - // (A,t[0]) := t[0] + x[0]*y[3] + A - MULXQ DI, AX, BP - ADOXQ AX, R14 - - // (A,t[1]) := t[1] + x[1]*y[3] + A - ADCXQ BP, R13 - MULXQ R8, AX, BP - ADOXQ AX, R13 - - // (A,t[2]) := t[2] + x[2]*y[3] + A - ADCXQ BP, CX - MULXQ R9, AX, BP - ADOXQ AX, CX - - // (A,t[3]) := t[3] + x[3]*y[3] + A - ADCXQ BP, BX - MULXQ R10, AX, BP - ADOXQ AX, BX - - // A += carries from ADCXQ and ADOXQ - MOVQ $0, AX - ADCXQ AX, BP - ADOXQ AX, BP - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ R14, DX - - // clear the flags - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ q<>+0(SB), AX, R12 - ADCXQ R14, AX - MOVQ R12, R14 - - // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R13, R14 - MULXQ q<>+8(SB), AX, R13 - ADOXQ AX, R14 - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ CX, R13 - MULXQ q<>+16(SB), AX, CX - ADOXQ AX, R13 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ BX, CX - MULXQ q<>+24(SB), AX, BX - ADOXQ AX, CX - - // t[3] = C + A - MOVQ $0, AX - ADCXQ AX, BX - ADOXQ BP, BX - - // reduce element(R14,R13,CX,BX) using temp registers (SI,R12,R11,DI) - REDUCE(R14,R13,CX,BX,SI,R12,R11,DI) - - MOVQ res+0(FP), AX - MOVQ R14, 0(AX) - MOVQ R13, 8(AX) - MOVQ CX, 16(AX) - MOVQ BX, 24(AX) - RET - -noAdx_1: - MOVQ res+0(FP), AX - MOVQ AX, (SP) - MOVQ x+8(FP), AX - MOVQ AX, 8(SP) - MOVQ y+16(FP), AX - MOVQ AX, 16(SP) - CALL ·_mulGeneric(SB) - RET - -TEXT ·fromMont(SB), $8-8 - NO_LOCAL_POINTERS - - // the algorithm is described here - // https://hackmd.io/@gnark/modular_multiplication - // when y = 1 we have: - // for i=0 to N-1 - // t[i] = x[i] - // for i=0 to N-1 - // m := t[0]*q'[0] mod W - // C,_ := t[0] + m*q[0] - // for j=1 to 
N-1 - // (C,t[j-1]) := t[j] + m*q[j] + C - // t[N-1] = C - CMPB ·supportAdx(SB), $1 - JNE noAdx_2 - MOVQ res+0(FP), DX - MOVQ 0(DX), R14 - MOVQ 8(DX), R13 - MOVQ 16(DX), CX - MOVQ 24(DX), BX - XORQ DX, DX - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ R14, DX - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ q<>+0(SB), AX, BP - ADCXQ R14, AX - MOVQ BP, R14 - - // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R13, R14 - MULXQ q<>+8(SB), AX, R13 - ADOXQ AX, R14 - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ CX, R13 - MULXQ q<>+16(SB), AX, CX - ADOXQ AX, R13 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ BX, CX - MULXQ q<>+24(SB), AX, BX - ADOXQ AX, CX - MOVQ $0, AX - ADCXQ AX, BX - ADOXQ AX, BX - XORQ DX, DX - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ R14, DX - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ q<>+0(SB), AX, BP - ADCXQ R14, AX - MOVQ BP, R14 - - // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R13, R14 - MULXQ q<>+8(SB), AX, R13 - ADOXQ AX, R14 - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ CX, R13 - MULXQ q<>+16(SB), AX, CX - ADOXQ AX, R13 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ BX, CX - MULXQ q<>+24(SB), AX, BX - ADOXQ AX, CX - MOVQ $0, AX - ADCXQ AX, BX - ADOXQ AX, BX - XORQ DX, DX - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ R14, DX - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ q<>+0(SB), AX, BP - ADCXQ R14, AX - MOVQ BP, R14 - - // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R13, R14 - MULXQ q<>+8(SB), AX, R13 - ADOXQ AX, R14 - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ CX, R13 - MULXQ q<>+16(SB), AX, CX - ADOXQ AX, R13 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ BX, CX - MULXQ q<>+24(SB), AX, BX - ADOXQ AX, CX - MOVQ $0, AX - ADCXQ AX, BX - ADOXQ AX, BX - XORQ DX, DX - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ R14, DX - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ q<>+0(SB), AX, BP - ADCXQ R14, AX - MOVQ BP, R14 - - // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R13, R14 - MULXQ q<>+8(SB), AX, R13 - ADOXQ AX, R14 - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ CX, R13 - MULXQ q<>+16(SB), AX, CX - ADOXQ AX, R13 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ BX, CX - MULXQ q<>+24(SB), AX, BX - ADOXQ AX, CX - MOVQ $0, AX - ADCXQ AX, BX - ADOXQ AX, BX - - // reduce element(R14,R13,CX,BX) using temp registers (SI,DI,R8,R9) - REDUCE(R14,R13,CX,BX,SI,DI,R8,R9) - - MOVQ res+0(FP), AX - MOVQ R14, 0(AX) - MOVQ R13, 8(AX) - MOVQ CX, 16(AX) - MOVQ BX, 24(AX) - RET - -noAdx_2: - MOVQ res+0(FP), AX - MOVQ AX, (SP) - CALL ·_fromMontGeneric(SB) - RET diff --git a/ecc/bls24-315/fr/element_ops_amd64.go b/ecc/bls24-315/fr/element_ops_amd64.go index 49da23450..d21dabdaa 100644 --- a/ecc/bls24-315/fr/element_ops_amd64.go +++ b/ecc/bls24-315/fr/element_ops_amd64.go @@ -103,48 +103,8 @@ func sumVec(res *Element, a *Element, n uint64) // x and y must be less than q func (z *Element) Mul(x, y *Element) *Element { - // Implements CIOS multiplication -- section 2.3.2 of Tolga Acar's thesis - // https://www.microsoft.com/en-us/research/wp-content/uploads/1998/06/97Acar.pdf - // - // The algorithm: - // - // for i=0 to N-1 - // C := 0 - // for j=0 to N-1 - // (C,t[j]) := t[j] + x[j]*y[i] + C - // (t[N+1],t[N]) := t[N] + C - // - // C := 0 - // m := t[0]*q'[0] mod D - // (C,_) := t[0] + m*q[0] - // for j=1 to N-1 - // (C,t[j-1]) := t[j] + m*q[j] + C - // - // (C,t[N-1]) := t[N] + C - // t[N] := t[N+1] + C - // - // → N is the number of machine words needed to store the modulus q - // → D is the word size. 
For example, on a 64-bit architecture D is 2 64 - // → x[i], y[i], q[i] is the ith word of the numbers x,y,q - // → q'[0] is the lowest word of the number -q⁻¹ mod r. This quantity is pre-computed, as it does not depend on the inputs. - // → t is a temporary array of size N+2 - // → C, S are machine words. A pair (C,S) refers to (hi-bits, lo-bits) of a two-word number - // - // As described here https://hackmd.io/@gnark/modular_multiplication we can get rid of one carry chain and simplify: - // (also described in https://eprint.iacr.org/2022/1400.pdf annex) - // - // for i=0 to N-1 - // (A,t[0]) := t[0] + x[0]*y[i] - // m := t[0]*q'[0] mod W - // C,_ := t[0] + m*q[0] - // for j=1 to N-1 - // (A,t[j]) := t[j] + x[j]*y[i] + A - // (C,t[j-1]) := t[j] + m*q[j] + C - // - // t[N-1] = C + A - // - // This optimization saves 5N + 2 additions in the algorithm, and can be used whenever the highest bit - // of the modulus is zero (and not all of the remaining bits are set). + // Algorithm 2 of "Faster Montgomery Multiplication and Multi-Scalar-Multiplication for SNARKS" + // by Y. El Housni and G. Botrel https://doi.org/10.46586/tches.v2023.i3.504-521 mul(z, x, y) return z diff --git a/ecc/bls24-315/fr/element_ops_amd64.s b/ecc/bls24-315/fr/element_ops_amd64.s index 45756951f..0fddea291 100644 --- a/ecc/bls24-315/fr/element_ops_amd64.s +++ b/ecc/bls24-315/fr/element_ops_amd64.s @@ -1,21 +1,11 @@ // +build !purego -// Copyright 2020 ConsenSys Software Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
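Like the other 4-word scalar fields in this patch, this file now keeps only the modulus limbs, qInv0 and the Barrett constant mu, and pulls the actual routines from the shared field/asm/element_4w_amd64.h header. The sumVec comment removed earlier in the patch defines mu as 2^288 / q for 4-word fields; assuming the same convention holds here, the following standalone math/big sketch (illustrative only, not part of the patch) recomputes mu from the q0..q3 limbs declared just below, and should reproduce the mu<> constant kept in the data section.

package main

import (
	"fmt"
	"math/big"
)

func main() {
	// bls24-315 Fr modulus, limbs concatenated q3..q0
	// (same values as the q0..q3 defines and q<> data section in this file).
	q, ok := new(big.Int).SetString(
		"196deac24a9da12b"+
			"25fc7ec9cf927a98"+
			"c8c480ece644e364"+
			"19d0c5fd00c00001", 16)
	if !ok {
		panic("invalid modulus literal")
	}

	// Single-word Barrett constant used by sumVec: mu = floor(2^288 / q),
	// assuming the 2^288 convention from the removed 4-word sumVec comment.
	mu := new(big.Int).Div(new(big.Int).Lsh(big.NewInt(1), 288), q)
	fmt.Printf("mu = %#x\n", mu)
}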
+#define q0 $0x19d0c5fd00c00001 +#define q1 $0xc8c480ece644e364 +#define q2 $0x25fc7ec9cf927a98 +#define q3 $0x196deac24a9da12b -#include "textflag.h" -#include "funcdata.h" +#include "../../../field/asm/element_4w_amd64.h" // modulus q DATA q<>+0(SB)/8, $0x19d0c5fd00c00001 @@ -31,809 +21,3 @@ GLOBL qInv0<>(SB), (RODATA+NOPTR), $8 DATA mu<>(SB)/8, $0x0000000a112d9c09 GLOBL mu<>(SB), (RODATA+NOPTR), $8 -#define REDUCE(ra0, ra1, ra2, ra3, rb0, rb1, rb2, rb3) \ - MOVQ ra0, rb0; \ - SUBQ q<>(SB), ra0; \ - MOVQ ra1, rb1; \ - SBBQ q<>+8(SB), ra1; \ - MOVQ ra2, rb2; \ - SBBQ q<>+16(SB), ra2; \ - MOVQ ra3, rb3; \ - SBBQ q<>+24(SB), ra3; \ - CMOVQCS rb0, ra0; \ - CMOVQCS rb1, ra1; \ - CMOVQCS rb2, ra2; \ - CMOVQCS rb3, ra3; \ - -TEXT ·reduce(SB), NOSPLIT, $0-8 - MOVQ res+0(FP), AX - MOVQ 0(AX), DX - MOVQ 8(AX), CX - MOVQ 16(AX), BX - MOVQ 24(AX), SI - - // reduce element(DX,CX,BX,SI) using temp registers (DI,R8,R9,R10) - REDUCE(DX,CX,BX,SI,DI,R8,R9,R10) - - MOVQ DX, 0(AX) - MOVQ CX, 8(AX) - MOVQ BX, 16(AX) - MOVQ SI, 24(AX) - RET - -// MulBy3(x *Element) -TEXT ·MulBy3(SB), NOSPLIT, $0-8 - MOVQ x+0(FP), AX - MOVQ 0(AX), DX - MOVQ 8(AX), CX - MOVQ 16(AX), BX - MOVQ 24(AX), SI - ADDQ DX, DX - ADCQ CX, CX - ADCQ BX, BX - ADCQ SI, SI - - // reduce element(DX,CX,BX,SI) using temp registers (DI,R8,R9,R10) - REDUCE(DX,CX,BX,SI,DI,R8,R9,R10) - - ADDQ 0(AX), DX - ADCQ 8(AX), CX - ADCQ 16(AX), BX - ADCQ 24(AX), SI - - // reduce element(DX,CX,BX,SI) using temp registers (R11,R12,R13,R14) - REDUCE(DX,CX,BX,SI,R11,R12,R13,R14) - - MOVQ DX, 0(AX) - MOVQ CX, 8(AX) - MOVQ BX, 16(AX) - MOVQ SI, 24(AX) - RET - -// MulBy5(x *Element) -TEXT ·MulBy5(SB), NOSPLIT, $0-8 - MOVQ x+0(FP), AX - MOVQ 0(AX), DX - MOVQ 8(AX), CX - MOVQ 16(AX), BX - MOVQ 24(AX), SI - ADDQ DX, DX - ADCQ CX, CX - ADCQ BX, BX - ADCQ SI, SI - - // reduce element(DX,CX,BX,SI) using temp registers (DI,R8,R9,R10) - REDUCE(DX,CX,BX,SI,DI,R8,R9,R10) - - ADDQ DX, DX - ADCQ CX, CX - ADCQ BX, BX - ADCQ SI, SI - - // reduce element(DX,CX,BX,SI) using temp registers (R11,R12,R13,R14) - REDUCE(DX,CX,BX,SI,R11,R12,R13,R14) - - ADDQ 0(AX), DX - ADCQ 8(AX), CX - ADCQ 16(AX), BX - ADCQ 24(AX), SI - - // reduce element(DX,CX,BX,SI) using temp registers (R15,DI,R8,R9) - REDUCE(DX,CX,BX,SI,R15,DI,R8,R9) - - MOVQ DX, 0(AX) - MOVQ CX, 8(AX) - MOVQ BX, 16(AX) - MOVQ SI, 24(AX) - RET - -// MulBy13(x *Element) -TEXT ·MulBy13(SB), NOSPLIT, $0-8 - MOVQ x+0(FP), AX - MOVQ 0(AX), DX - MOVQ 8(AX), CX - MOVQ 16(AX), BX - MOVQ 24(AX), SI - ADDQ DX, DX - ADCQ CX, CX - ADCQ BX, BX - ADCQ SI, SI - - // reduce element(DX,CX,BX,SI) using temp registers (DI,R8,R9,R10) - REDUCE(DX,CX,BX,SI,DI,R8,R9,R10) - - ADDQ DX, DX - ADCQ CX, CX - ADCQ BX, BX - ADCQ SI, SI - - // reduce element(DX,CX,BX,SI) using temp registers (R11,R12,R13,R14) - REDUCE(DX,CX,BX,SI,R11,R12,R13,R14) - - MOVQ DX, R11 - MOVQ CX, R12 - MOVQ BX, R13 - MOVQ SI, R14 - ADDQ DX, DX - ADCQ CX, CX - ADCQ BX, BX - ADCQ SI, SI - - // reduce element(DX,CX,BX,SI) using temp registers (DI,R8,R9,R10) - REDUCE(DX,CX,BX,SI,DI,R8,R9,R10) - - ADDQ R11, DX - ADCQ R12, CX - ADCQ R13, BX - ADCQ R14, SI - - // reduce element(DX,CX,BX,SI) using temp registers (DI,R8,R9,R10) - REDUCE(DX,CX,BX,SI,DI,R8,R9,R10) - - ADDQ 0(AX), DX - ADCQ 8(AX), CX - ADCQ 16(AX), BX - ADCQ 24(AX), SI - - // reduce element(DX,CX,BX,SI) using temp registers (DI,R8,R9,R10) - REDUCE(DX,CX,BX,SI,DI,R8,R9,R10) - - MOVQ DX, 0(AX) - MOVQ CX, 8(AX) - MOVQ BX, 16(AX) - MOVQ SI, 24(AX) - RET - -// Butterfly(a, b *Element) sets a = a + b; b = a - b -TEXT ·Butterfly(SB), 
NOSPLIT, $0-16 - MOVQ a+0(FP), AX - MOVQ 0(AX), CX - MOVQ 8(AX), BX - MOVQ 16(AX), SI - MOVQ 24(AX), DI - MOVQ CX, R8 - MOVQ BX, R9 - MOVQ SI, R10 - MOVQ DI, R11 - XORQ AX, AX - MOVQ b+8(FP), DX - ADDQ 0(DX), CX - ADCQ 8(DX), BX - ADCQ 16(DX), SI - ADCQ 24(DX), DI - SUBQ 0(DX), R8 - SBBQ 8(DX), R9 - SBBQ 16(DX), R10 - SBBQ 24(DX), R11 - MOVQ $0x19d0c5fd00c00001, R12 - MOVQ $0xc8c480ece644e364, R13 - MOVQ $0x25fc7ec9cf927a98, R14 - MOVQ $0x196deac24a9da12b, R15 - CMOVQCC AX, R12 - CMOVQCC AX, R13 - CMOVQCC AX, R14 - CMOVQCC AX, R15 - ADDQ R12, R8 - ADCQ R13, R9 - ADCQ R14, R10 - ADCQ R15, R11 - MOVQ R8, 0(DX) - MOVQ R9, 8(DX) - MOVQ R10, 16(DX) - MOVQ R11, 24(DX) - - // reduce element(CX,BX,SI,DI) using temp registers (R8,R9,R10,R11) - REDUCE(CX,BX,SI,DI,R8,R9,R10,R11) - - MOVQ a+0(FP), AX - MOVQ CX, 0(AX) - MOVQ BX, 8(AX) - MOVQ SI, 16(AX) - MOVQ DI, 24(AX) - RET - -// addVec(res, a, b *Element, n uint64) res[0...n] = a[0...n] + b[0...n] -TEXT ·addVec(SB), NOSPLIT, $0-32 - MOVQ res+0(FP), CX - MOVQ a+8(FP), AX - MOVQ b+16(FP), DX - MOVQ n+24(FP), BX - -loop_1: - TESTQ BX, BX - JEQ done_2 // n == 0, we are done - - // a[0] -> SI - // a[1] -> DI - // a[2] -> R8 - // a[3] -> R9 - MOVQ 0(AX), SI - MOVQ 8(AX), DI - MOVQ 16(AX), R8 - MOVQ 24(AX), R9 - ADDQ 0(DX), SI - ADCQ 8(DX), DI - ADCQ 16(DX), R8 - ADCQ 24(DX), R9 - - // reduce element(SI,DI,R8,R9) using temp registers (R10,R11,R12,R13) - REDUCE(SI,DI,R8,R9,R10,R11,R12,R13) - - MOVQ SI, 0(CX) - MOVQ DI, 8(CX) - MOVQ R8, 16(CX) - MOVQ R9, 24(CX) - - // increment pointers to visit next element - ADDQ $32, AX - ADDQ $32, DX - ADDQ $32, CX - DECQ BX // decrement n - JMP loop_1 - -done_2: - RET - -// subVec(res, a, b *Element, n uint64) res[0...n] = a[0...n] - b[0...n] -TEXT ·subVec(SB), NOSPLIT, $0-32 - MOVQ res+0(FP), CX - MOVQ a+8(FP), AX - MOVQ b+16(FP), DX - MOVQ n+24(FP), BX - XORQ SI, SI - -loop_3: - TESTQ BX, BX - JEQ done_4 // n == 0, we are done - - // a[0] -> DI - // a[1] -> R8 - // a[2] -> R9 - // a[3] -> R10 - MOVQ 0(AX), DI - MOVQ 8(AX), R8 - MOVQ 16(AX), R9 - MOVQ 24(AX), R10 - SUBQ 0(DX), DI - SBBQ 8(DX), R8 - SBBQ 16(DX), R9 - SBBQ 24(DX), R10 - - // reduce (a-b) mod q - // q[0] -> R11 - // q[1] -> R12 - // q[2] -> R13 - // q[3] -> R14 - MOVQ $0x19d0c5fd00c00001, R11 - MOVQ $0xc8c480ece644e364, R12 - MOVQ $0x25fc7ec9cf927a98, R13 - MOVQ $0x196deac24a9da12b, R14 - CMOVQCC SI, R11 - CMOVQCC SI, R12 - CMOVQCC SI, R13 - CMOVQCC SI, R14 - - // add registers (q or 0) to a, and set to result - ADDQ R11, DI - ADCQ R12, R8 - ADCQ R13, R9 - ADCQ R14, R10 - MOVQ DI, 0(CX) - MOVQ R8, 8(CX) - MOVQ R9, 16(CX) - MOVQ R10, 24(CX) - - // increment pointers to visit next element - ADDQ $32, AX - ADDQ $32, DX - ADDQ $32, CX - DECQ BX // decrement n - JMP loop_3 - -done_4: - RET - -// scalarMulVec(res, a, b *Element, n uint64) res[0...n] = a[0...n] * b -TEXT ·scalarMulVec(SB), $56-32 - CMPB ·supportAdx(SB), $1 - JNE noAdx_5 - MOVQ a+8(FP), R11 - MOVQ b+16(FP), R10 - MOVQ n+24(FP), R12 - - // scalar[0] -> SI - // scalar[1] -> DI - // scalar[2] -> R8 - // scalar[3] -> R9 - MOVQ 0(R10), SI - MOVQ 8(R10), DI - MOVQ 16(R10), R8 - MOVQ 24(R10), R9 - MOVQ res+0(FP), R10 - -loop_6: - TESTQ R12, R12 - JEQ done_7 // n == 0, we are done - - // TODO @gbotrel this is generated from the same macro as the unit mul, we should refactor this in a single asm function - // A -> BP - // t[0] -> R14 - // t[1] -> R15 - // t[2] -> CX - // t[3] -> BX - // clear the flags - XORQ AX, AX - MOVQ 0(R11), DX - - // (A,t[0]) := x[0]*y[0] + A - MULXQ SI, R14, R15 - - // (A,t[1]) := 
x[1]*y[0] + A - MULXQ DI, AX, CX - ADOXQ AX, R15 - - // (A,t[2]) := x[2]*y[0] + A - MULXQ R8, AX, BX - ADOXQ AX, CX - - // (A,t[3]) := x[3]*y[0] + A - MULXQ R9, AX, BP - ADOXQ AX, BX - - // A += carries from ADCXQ and ADOXQ - MOVQ $0, AX - ADOXQ AX, BP - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ R14, DX - - // clear the flags - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ q<>+0(SB), AX, R13 - ADCXQ R14, AX - MOVQ R13, R14 - - // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R15, R14 - MULXQ q<>+8(SB), AX, R15 - ADOXQ AX, R14 - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ CX, R15 - MULXQ q<>+16(SB), AX, CX - ADOXQ AX, R15 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ BX, CX - MULXQ q<>+24(SB), AX, BX - ADOXQ AX, CX - - // t[3] = C + A - MOVQ $0, AX - ADCXQ AX, BX - ADOXQ BP, BX - - // clear the flags - XORQ AX, AX - MOVQ 8(R11), DX - - // (A,t[0]) := t[0] + x[0]*y[1] + A - MULXQ SI, AX, BP - ADOXQ AX, R14 - - // (A,t[1]) := t[1] + x[1]*y[1] + A - ADCXQ BP, R15 - MULXQ DI, AX, BP - ADOXQ AX, R15 - - // (A,t[2]) := t[2] + x[2]*y[1] + A - ADCXQ BP, CX - MULXQ R8, AX, BP - ADOXQ AX, CX - - // (A,t[3]) := t[3] + x[3]*y[1] + A - ADCXQ BP, BX - MULXQ R9, AX, BP - ADOXQ AX, BX - - // A += carries from ADCXQ and ADOXQ - MOVQ $0, AX - ADCXQ AX, BP - ADOXQ AX, BP - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ R14, DX - - // clear the flags - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ q<>+0(SB), AX, R13 - ADCXQ R14, AX - MOVQ R13, R14 - - // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R15, R14 - MULXQ q<>+8(SB), AX, R15 - ADOXQ AX, R14 - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ CX, R15 - MULXQ q<>+16(SB), AX, CX - ADOXQ AX, R15 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ BX, CX - MULXQ q<>+24(SB), AX, BX - ADOXQ AX, CX - - // t[3] = C + A - MOVQ $0, AX - ADCXQ AX, BX - ADOXQ BP, BX - - // clear the flags - XORQ AX, AX - MOVQ 16(R11), DX - - // (A,t[0]) := t[0] + x[0]*y[2] + A - MULXQ SI, AX, BP - ADOXQ AX, R14 - - // (A,t[1]) := t[1] + x[1]*y[2] + A - ADCXQ BP, R15 - MULXQ DI, AX, BP - ADOXQ AX, R15 - - // (A,t[2]) := t[2] + x[2]*y[2] + A - ADCXQ BP, CX - MULXQ R8, AX, BP - ADOXQ AX, CX - - // (A,t[3]) := t[3] + x[3]*y[2] + A - ADCXQ BP, BX - MULXQ R9, AX, BP - ADOXQ AX, BX - - // A += carries from ADCXQ and ADOXQ - MOVQ $0, AX - ADCXQ AX, BP - ADOXQ AX, BP - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ R14, DX - - // clear the flags - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ q<>+0(SB), AX, R13 - ADCXQ R14, AX - MOVQ R13, R14 - - // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R15, R14 - MULXQ q<>+8(SB), AX, R15 - ADOXQ AX, R14 - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ CX, R15 - MULXQ q<>+16(SB), AX, CX - ADOXQ AX, R15 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ BX, CX - MULXQ q<>+24(SB), AX, BX - ADOXQ AX, CX - - // t[3] = C + A - MOVQ $0, AX - ADCXQ AX, BX - ADOXQ BP, BX - - // clear the flags - XORQ AX, AX - MOVQ 24(R11), DX - - // (A,t[0]) := t[0] + x[0]*y[3] + A - MULXQ SI, AX, BP - ADOXQ AX, R14 - - // (A,t[1]) := t[1] + x[1]*y[3] + A - ADCXQ BP, R15 - MULXQ DI, AX, BP - ADOXQ AX, R15 - - // (A,t[2]) := t[2] + x[2]*y[3] + A - ADCXQ BP, CX - MULXQ R8, AX, BP - ADOXQ AX, CX - - // (A,t[3]) := t[3] + x[3]*y[3] + A - ADCXQ BP, BX - MULXQ R9, AX, BP - ADOXQ AX, BX - - // A += carries from ADCXQ and ADOXQ - MOVQ $0, AX - ADCXQ AX, BP - ADOXQ AX, BP - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ R14, DX - - // clear the flags - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ q<>+0(SB), AX, R13 - ADCXQ R14, AX - MOVQ R13, R14 - - // (C,t[0]) := t[1] + 
m*q[1] + C - ADCXQ R15, R14 - MULXQ q<>+8(SB), AX, R15 - ADOXQ AX, R14 - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ CX, R15 - MULXQ q<>+16(SB), AX, CX - ADOXQ AX, R15 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ BX, CX - MULXQ q<>+24(SB), AX, BX - ADOXQ AX, CX - - // t[3] = C + A - MOVQ $0, AX - ADCXQ AX, BX - ADOXQ BP, BX - - // reduce t mod q - // reduce element(R14,R15,CX,BX) using temp registers (R13,AX,DX,s0-8(SP)) - REDUCE(R14,R15,CX,BX,R13,AX,DX,s0-8(SP)) - - MOVQ R14, 0(R10) - MOVQ R15, 8(R10) - MOVQ CX, 16(R10) - MOVQ BX, 24(R10) - - // increment pointers to visit next element - ADDQ $32, R11 - ADDQ $32, R10 - DECQ R12 // decrement n - JMP loop_6 - -done_7: - RET - -noAdx_5: - MOVQ n+24(FP), DX - MOVQ res+0(FP), AX - MOVQ AX, (SP) - MOVQ DX, 8(SP) - MOVQ DX, 16(SP) - MOVQ a+8(FP), AX - MOVQ AX, 24(SP) - MOVQ DX, 32(SP) - MOVQ DX, 40(SP) - MOVQ b+16(FP), AX - MOVQ AX, 48(SP) - CALL ·scalarMulVecGeneric(SB) - RET - -// sumVec(res, a *Element, n uint64) res = sum(a[0...n]) -TEXT ·sumVec(SB), NOSPLIT, $0-24 - - // Derived from https://github.com/a16z/vectorized-fields - // The idea is to use Z registers to accumulate the sum of elements, 8 by 8 - // first, we handle the case where n % 8 != 0 - // then, we loop over the elements 8 by 8 and accumulate the sum in the Z registers - // finally, we reduce the sum and store it in res - // - // when we move an element of a into a Z register, we use VPMOVZXDQ - // let's note w0...w3 the 4 64bits words of ai: w0 = ai[0], w1 = ai[1], w2 = ai[2], w3 = ai[3] - // VPMOVZXDQ(ai, Z0) will result in - // Z0= [hi(w3), lo(w3), hi(w2), lo(w2), hi(w1), lo(w1), hi(w0), lo(w0)] - // with hi(wi) the high 32 bits of wi and lo(wi) the low 32 bits of wi - // we can safely add 2^32+1 times Z registers constructed this way without overflow - // since each of this lo/hi bits are moved into a "64bits" slot - // N = 2^64-1 / 2^32-1 = 2^32+1 - // - // we then propagate the carry using ADOXQ and ADCXQ - // r0 = w0l + lo(woh) - // r1 = carry + hi(woh) + w1l + lo(w1h) - // r2 = carry + hi(w1h) + w2l + lo(w2h) - // r3 = carry + hi(w2h) + w3l + lo(w3h) - // r4 = carry + hi(w3h) - // we then reduce the sum using a single-word Barrett reduction - // we pick mu = 2^288 / q; which correspond to 4.5 words max. - // meaning we must guarantee that r4 fits in 32bits. 
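
The sketch below restates, in scalar Go, the two steps this comment walks through: folding the eight accumulated lanes into the 5-word integer r0…r4, then the single-word Barrett reduction with mu = ⌊2²⁸⁸ / q⌋. It is illustrative only; the function names, the fixed 4-word modulus, and the two trailing conditional subtractions (mirroring the two rounds in the assembly) are assumptions of this sketch, not generated code.

    package sumvec

    import "math/bits"

    // propagate folds the eight accumulated lanes S[0..7] (S[2k] = running sum
    // of the low 32 bits of word k, S[2k+1] = running sum of the high 32 bits)
    // into the 5-word value r0..r4, using two independent carry chains in the
    // spirit of the ADCXQ/ADOXQ pair.
    func propagate(S [8]uint64) (r [5]uint64) {
        var c1, c2 uint64
        r[0], c1 = bits.Add64(S[0], S[1]<<32, 0)
        r[1], c1 = bits.Add64(S[2], S[1]>>32, c1)
        r[1], c2 = bits.Add64(r[1], S[3]<<32, 0)
        r[2], c1 = bits.Add64(S[4], S[3]>>32, c1)
        r[2], c2 = bits.Add64(r[2], S[5]<<32, c2)
        r[3], c1 = bits.Add64(S[6], S[5]>>32, c1)
        r[3], c2 = bits.Add64(r[3], S[7]<<32, c2)
        r[4] = S[7]>>32 + c1 + c2
        return r
    }

    // barrettReduce applies the single-word Barrett step with mu = floor(2^288/q),
    // estimating the quotient from the top 64 bits of r (r[4] fits in 32 bits),
    // then finishes with conditional subtractions of q.
    func barrettReduce(r [5]uint64, q [4]uint64, mu uint64) [4]uint64 {
        qHat, _ := bits.Mul64(r[3]>>32|r[4]<<32, mu) // keep the high 64 bits

        // p = qHat * q, as 5 words
        var p [5]uint64
        var carry uint64
        for i := 0; i < 4; i++ {
            hi, lo := bits.Mul64(qHat, q[i])
            var c uint64
            p[i], c = bits.Add64(lo, carry, 0)
            carry = hi + c
        }
        p[4] = carry

        // r -= p; the estimate never overshoots, so this cannot underflow
        var b uint64
        for i := 0; i < 5; i++ {
            r[i], b = bits.Sub64(r[i], p[i], b)
        }

        // mirror the two conditional subtraction rounds performed by the assembly
        for k := 0; k < 2; k++ {
            var t [5]uint64
            var bb uint64
            for i := 0; i < 4; i++ {
                t[i], bb = bits.Sub64(r[i], q[i], bb)
            }
            t[4], bb = bits.Sub64(r[4], 0, bb)
            if bb == 0 { // no borrow: keep the subtracted value
                r = t
            }
        }
        return [4]uint64{r[0], r[1], r[2], r[3]}
    }
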
- // To do so, we reduce N to 2^32-1 (since r4 receives 2 carries max) - - MOVQ a+8(FP), R14 - MOVQ n+16(FP), R15 - - // initialize accumulators Z0, Z1, Z2, Z3, Z4, Z5, Z6, Z7 - VXORPS Z0, Z0, Z0 - VMOVDQA64 Z0, Z1 - VMOVDQA64 Z0, Z2 - VMOVDQA64 Z0, Z3 - VMOVDQA64 Z0, Z4 - VMOVDQA64 Z0, Z5 - VMOVDQA64 Z0, Z6 - VMOVDQA64 Z0, Z7 - - // n % 8 -> CX - // n / 8 -> R15 - MOVQ R15, CX - ANDQ $7, CX - SHRQ $3, R15 - -loop_single_10: - TESTQ CX, CX - JEQ loop8by8_8 // n % 8 == 0, we are going to loop over 8 by 8 - VPMOVZXDQ 0(R14), Z8 - VPADDQ Z8, Z0, Z0 - ADDQ $32, R14 - DECQ CX // decrement nMod8 - JMP loop_single_10 - -loop8by8_8: - TESTQ R15, R15 - JEQ accumulate_11 // n == 0, we are going to accumulate - VPMOVZXDQ 0*32(R14), Z8 - VPMOVZXDQ 1*32(R14), Z9 - VPMOVZXDQ 2*32(R14), Z10 - VPMOVZXDQ 3*32(R14), Z11 - VPMOVZXDQ 4*32(R14), Z12 - VPMOVZXDQ 5*32(R14), Z13 - VPMOVZXDQ 6*32(R14), Z14 - VPMOVZXDQ 7*32(R14), Z15 - PREFETCHT0 256(R14) - VPADDQ Z8, Z0, Z0 - VPADDQ Z9, Z1, Z1 - VPADDQ Z10, Z2, Z2 - VPADDQ Z11, Z3, Z3 - VPADDQ Z12, Z4, Z4 - VPADDQ Z13, Z5, Z5 - VPADDQ Z14, Z6, Z6 - VPADDQ Z15, Z7, Z7 - - // increment pointers to visit next 8 elements - ADDQ $256, R14 - DECQ R15 // decrement n - JMP loop8by8_8 - -accumulate_11: - // accumulate the 8 Z registers into Z0 - VPADDQ Z7, Z6, Z6 - VPADDQ Z6, Z5, Z5 - VPADDQ Z5, Z4, Z4 - VPADDQ Z4, Z3, Z3 - VPADDQ Z3, Z2, Z2 - VPADDQ Z2, Z1, Z1 - VPADDQ Z1, Z0, Z0 - - // carry propagation - // lo(w0) -> BX - // hi(w0) -> SI - // lo(w1) -> DI - // hi(w1) -> R8 - // lo(w2) -> R9 - // hi(w2) -> R10 - // lo(w3) -> R11 - // hi(w3) -> R12 - VMOVQ X0, BX - VALIGNQ $1, Z0, Z0, Z0 - VMOVQ X0, SI - VALIGNQ $1, Z0, Z0, Z0 - VMOVQ X0, DI - VALIGNQ $1, Z0, Z0, Z0 - VMOVQ X0, R8 - VALIGNQ $1, Z0, Z0, Z0 - VMOVQ X0, R9 - VALIGNQ $1, Z0, Z0, Z0 - VMOVQ X0, R10 - VALIGNQ $1, Z0, Z0, Z0 - VMOVQ X0, R11 - VALIGNQ $1, Z0, Z0, Z0 - VMOVQ X0, R12 - - // lo(hi(wo)) -> R13 - // lo(hi(w1)) -> CX - // lo(hi(w2)) -> R15 - // lo(hi(w3)) -> R14 -#define SPLIT_LO_HI(lo, hi) \ - MOVQ hi, lo; \ - ANDQ $0xffffffff, lo; \ - SHLQ $32, lo; \ - SHRQ $32, hi; \ - - SPLIT_LO_HI(R13, SI) - SPLIT_LO_HI(CX, R8) - SPLIT_LO_HI(R15, R10) - SPLIT_LO_HI(R14, R12) - - // r0 = w0l + lo(woh) - // r1 = carry + hi(woh) + w1l + lo(w1h) - // r2 = carry + hi(w1h) + w2l + lo(w2h) - // r3 = carry + hi(w2h) + w3l + lo(w3h) - // r4 = carry + hi(w3h) - - XORQ AX, AX // clear the flags - ADOXQ R13, BX - ADOXQ CX, DI - ADCXQ SI, DI - ADOXQ R15, R9 - ADCXQ R8, R9 - ADOXQ R14, R11 - ADCXQ R10, R11 - ADOXQ AX, R12 - ADCXQ AX, R12 - - // r[0] -> BX - // r[1] -> DI - // r[2] -> R9 - // r[3] -> R11 - // r[4] -> R12 - // reduce using single-word Barrett - // mu=2^288 / q -> SI - MOVQ mu<>(SB), SI - MOVQ R11, AX - SHRQ $32, R12, AX - MULQ SI // high bits of res stored in DX - MULXQ q<>+0(SB), AX, SI - SUBQ AX, BX - SBBQ SI, DI - MULXQ q<>+16(SB), AX, SI - SBBQ AX, R9 - SBBQ SI, R11 - SBBQ $0, R12 - MULXQ q<>+8(SB), AX, SI - SUBQ AX, DI - SBBQ SI, R9 - MULXQ q<>+24(SB), AX, SI - SBBQ AX, R11 - SBBQ SI, R12 - MOVQ BX, R8 - MOVQ DI, R10 - MOVQ R9, R13 - MOVQ R11, CX - SUBQ q<>+0(SB), BX - SBBQ q<>+8(SB), DI - SBBQ q<>+16(SB), R9 - SBBQ q<>+24(SB), R11 - SBBQ $0, R12 - JCS modReduced_12 - MOVQ BX, R8 - MOVQ DI, R10 - MOVQ R9, R13 - MOVQ R11, CX - SUBQ q<>+0(SB), BX - SBBQ q<>+8(SB), DI - SBBQ q<>+16(SB), R9 - SBBQ q<>+24(SB), R11 - SBBQ $0, R12 - JCS modReduced_12 - MOVQ BX, R8 - MOVQ DI, R10 - MOVQ R9, R13 - MOVQ R11, CX - -modReduced_12: - MOVQ res+0(FP), SI - MOVQ R8, 0(SI) - MOVQ R10, 8(SI) - MOVQ R13, 16(SI) - MOVQ CX, 
24(SI) - -done_9: - RET diff --git a/ecc/bls24-315/fr/element_ops_purego.go b/ecc/bls24-315/fr/element_ops_purego.go index 2149e5d8b..e1c69349f 100644 --- a/ecc/bls24-315/fr/element_ops_purego.go +++ b/ecc/bls24-315/fr/element_ops_purego.go @@ -89,48 +89,8 @@ func (vector *Vector) Sum() (res Element) { // x and y must be less than q func (z *Element) Mul(x, y *Element) *Element { - // Implements CIOS multiplication -- section 2.3.2 of Tolga Acar's thesis - // https://www.microsoft.com/en-us/research/wp-content/uploads/1998/06/97Acar.pdf - // - // The algorithm: - // - // for i=0 to N-1 - // C := 0 - // for j=0 to N-1 - // (C,t[j]) := t[j] + x[j]*y[i] + C - // (t[N+1],t[N]) := t[N] + C - // - // C := 0 - // m := t[0]*q'[0] mod D - // (C,_) := t[0] + m*q[0] - // for j=1 to N-1 - // (C,t[j-1]) := t[j] + m*q[j] + C - // - // (C,t[N-1]) := t[N] + C - // t[N] := t[N+1] + C - // - // → N is the number of machine words needed to store the modulus q - // → D is the word size. For example, on a 64-bit architecture D is 2 64 - // → x[i], y[i], q[i] is the ith word of the numbers x,y,q - // → q'[0] is the lowest word of the number -q⁻¹ mod r. This quantity is pre-computed, as it does not depend on the inputs. - // → t is a temporary array of size N+2 - // → C, S are machine words. A pair (C,S) refers to (hi-bits, lo-bits) of a two-word number - // - // As described here https://hackmd.io/@gnark/modular_multiplication we can get rid of one carry chain and simplify: - // (also described in https://eprint.iacr.org/2022/1400.pdf annex) - // - // for i=0 to N-1 - // (A,t[0]) := t[0] + x[0]*y[i] - // m := t[0]*q'[0] mod W - // C,_ := t[0] + m*q[0] - // for j=1 to N-1 - // (A,t[j]) := t[j] + x[j]*y[i] + A - // (C,t[j-1]) := t[j] + m*q[j] + C - // - // t[N-1] = C + A - // - // This optimization saves 5N + 2 additions in the algorithm, and can be used whenever the highest bit - // of the modulus is zero (and not all of the remaining bits are set). + // Algorithm 2 of "Faster Montgomery Multiplication and Multi-Scalar-Multiplication for SNARKS" + // by Y. El Housni and G. Botrel https://doi.org/10.46586/tches.v2023.i3.504-521 var t0, t1, t2, t3 uint64 var u0, u1, u2, u3 uint64 diff --git a/ecc/bls24-315/fr/element_test.go b/ecc/bls24-315/fr/element_test.go index e382132bd..1db4b00ae 100644 --- a/ecc/bls24-315/fr/element_test.go +++ b/ecc/bls24-315/fr/element_test.go @@ -637,7 +637,6 @@ func TestElementBitLen(t *testing.T) { )) properties.TestingRun(t, gopter.ConsoleReporter(false)) - } func TestElementButterflies(t *testing.T) { @@ -719,11 +718,17 @@ func TestElementVecOps(t *testing.T) { // set m to max values element // it's not really q-1 (since we have montgomery representation) // but it's the "largest" legal value - qMinus1 := new(big.Int).Sub(Modulus(), big.NewInt(1)) - - var eQMinus1 Element - for i, v := range qMinus1.Bits() { - eQMinus1[i] = uint64(v) + eQMinus1 := qElement + if eQMinus1[0] != 0 { + eQMinus1[0]-- + } else { + eQMinus1[0] = ^uint64(0) + for i := 1; i < len(eQMinus1); i++ { + if eQMinus1[i] != 0 { + eQMinus1[i]-- + break + } + } } for i := 0; i < N; i++ { diff --git a/ecc/bls24-317/fp/element.go b/ecc/bls24-317/fp/element.go index 652a4a78e..77818de47 100644 --- a/ecc/bls24-317/fp/element.go +++ b/ecc/bls24-317/fp/element.go @@ -499,32 +499,8 @@ func (z *Element) Select(c int, x0 *Element, x1 *Element) *Element { // and is used for testing purposes. 
func _mulGeneric(z, x, y *Element) { - // Implements CIOS multiplication -- section 2.3.2 of Tolga Acar's thesis - // https://www.microsoft.com/en-us/research/wp-content/uploads/1998/06/97Acar.pdf - // - // The algorithm: - // - // for i=0 to N-1 - // C := 0 - // for j=0 to N-1 - // (C,t[j]) := t[j] + x[j]*y[i] + C - // (t[N+1],t[N]) := t[N] + C - // - // C := 0 - // m := t[0]*q'[0] mod D - // (C,_) := t[0] + m*q[0] - // for j=1 to N-1 - // (C,t[j-1]) := t[j] + m*q[j] + C - // - // (C,t[N-1]) := t[N] + C - // t[N] := t[N+1] + C - // - // → N is the number of machine words needed to store the modulus q - // → D is the word size. For example, on a 64-bit architecture D is 2 64 - // → x[i], y[i], q[i] is the ith word of the numbers x,y,q - // → q'[0] is the lowest word of the number -q⁻¹ mod r. This quantity is pre-computed, as it does not depend on the inputs. - // → t is a temporary array of size N+2 - // → C, S are machine words. A pair (C,S) refers to (hi-bits, lo-bits) of a two-word number + // Algorithm 2 of "Faster Montgomery Multiplication and Multi-Scalar-Multiplication for SNARKS" + // by Y. El Housni and G. Botrel https://doi.org/10.46586/tches.v2023.i3.504-521 var t [6]uint64 var D uint64 diff --git a/ecc/bls24-317/fp/element_mul_amd64.s b/ecc/bls24-317/fp/element_mul_amd64.s deleted file mode 100644 index bfc863eeb..000000000 --- a/ecc/bls24-317/fp/element_mul_amd64.s +++ /dev/null @@ -1,656 +0,0 @@ -// +build !purego - -// Copyright 2020 ConsenSys Software Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
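
Throughout this patch, the deleted per-curve assembly (including the file below) relies on a REDUCE macro: subtract q word by word with borrow, then use CMOVQCS to restore the original value when the subtraction borrowed. A scalar Go equivalent for the 5-word elements of this field, as a sketch (the name reduce5 and the pointer signature are illustrative):

    package fp

    import "math/bits"

    // reduce5 conditionally subtracts q from z once, keeping the difference
    // only when the subtraction does not borrow (i.e. when z >= q) — the
    // scalar counterpart of the SUBQ/SBBQ + CMOVQCS sequence in REDUCE.
    func reduce5(z *[5]uint64, q [5]uint64) {
        var t [5]uint64
        var b uint64
        for i := 0; i < 5; i++ {
            t[i], b = bits.Sub64(z[i], q[i], b)
        }
        if b == 0 { // no borrow: z >= q, keep z - q
            *z = t
        }
    }
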
- -#include "textflag.h" -#include "funcdata.h" - -// modulus q -DATA q<>+0(SB)/8, $0x8d512e565dab2aab -DATA q<>+8(SB)/8, $0xd6f339e43424bf7e -DATA q<>+16(SB)/8, $0x169a61e684c73446 -DATA q<>+24(SB)/8, $0xf28fc5a0b7f9d039 -DATA q<>+32(SB)/8, $0x1058ca226f60892c -GLOBL q<>(SB), (RODATA+NOPTR), $40 - -// qInv0 q'[0] -DATA qInv0<>(SB)/8, $0x55b5e0028b047ffd -GLOBL qInv0<>(SB), (RODATA+NOPTR), $8 - -#define REDUCE(ra0, ra1, ra2, ra3, ra4, rb0, rb1, rb2, rb3, rb4) \ - MOVQ ra0, rb0; \ - SUBQ q<>(SB), ra0; \ - MOVQ ra1, rb1; \ - SBBQ q<>+8(SB), ra1; \ - MOVQ ra2, rb2; \ - SBBQ q<>+16(SB), ra2; \ - MOVQ ra3, rb3; \ - SBBQ q<>+24(SB), ra3; \ - MOVQ ra4, rb4; \ - SBBQ q<>+32(SB), ra4; \ - CMOVQCS rb0, ra0; \ - CMOVQCS rb1, ra1; \ - CMOVQCS rb2, ra2; \ - CMOVQCS rb3, ra3; \ - CMOVQCS rb4, ra4; \ - -// mul(res, x, y *Element) -TEXT ·mul(SB), $24-24 - - // the algorithm is described in the Element.Mul declaration (.go) - // however, to benefit from the ADCX and ADOX carry chains - // we split the inner loops in 2: - // for i=0 to N-1 - // for j=0 to N-1 - // (A,t[j]) := t[j] + x[j]*y[i] + A - // m := t[0]*q'[0] mod W - // C,_ := t[0] + m*q[0] - // for j=1 to N-1 - // (C,t[j-1]) := t[j] + m*q[j] + C - // t[N-1] = C + A - - NO_LOCAL_POINTERS - CMPB ·supportAdx(SB), $1 - JNE noAdx_1 - MOVQ x+8(FP), DI - - // x[0] -> R9 - // x[1] -> R10 - // x[2] -> R11 - MOVQ 0(DI), R9 - MOVQ 8(DI), R10 - MOVQ 16(DI), R11 - MOVQ y+16(FP), R12 - - // A -> BP - // t[0] -> R14 - // t[1] -> R13 - // t[2] -> CX - // t[3] -> BX - // t[4] -> SI - // clear the flags - XORQ AX, AX - MOVQ 0(R12), DX - - // (A,t[0]) := x[0]*y[0] + A - MULXQ R9, R14, R13 - - // (A,t[1]) := x[1]*y[0] + A - MULXQ R10, AX, CX - ADOXQ AX, R13 - - // (A,t[2]) := x[2]*y[0] + A - MULXQ R11, AX, BX - ADOXQ AX, CX - - // (A,t[3]) := x[3]*y[0] + A - MULXQ 24(DI), AX, SI - ADOXQ AX, BX - - // (A,t[4]) := x[4]*y[0] + A - MULXQ 32(DI), AX, BP - ADOXQ AX, SI - - // A += carries from ADCXQ and ADOXQ - MOVQ $0, AX - ADOXQ AX, BP - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ R14, DX - - // clear the flags - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ q<>+0(SB), AX, R8 - ADCXQ R14, AX - MOVQ R8, R14 - - // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R13, R14 - MULXQ q<>+8(SB), AX, R13 - ADOXQ AX, R14 - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ CX, R13 - MULXQ q<>+16(SB), AX, CX - ADOXQ AX, R13 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ BX, CX - MULXQ q<>+24(SB), AX, BX - ADOXQ AX, CX - - // (C,t[3]) := t[4] + m*q[4] + C - ADCXQ SI, BX - MULXQ q<>+32(SB), AX, SI - ADOXQ AX, BX - - // t[4] = C + A - MOVQ $0, AX - ADCXQ AX, SI - ADOXQ BP, SI - - // clear the flags - XORQ AX, AX - MOVQ 8(R12), DX - - // (A,t[0]) := t[0] + x[0]*y[1] + A - MULXQ R9, AX, BP - ADOXQ AX, R14 - - // (A,t[1]) := t[1] + x[1]*y[1] + A - ADCXQ BP, R13 - MULXQ R10, AX, BP - ADOXQ AX, R13 - - // (A,t[2]) := t[2] + x[2]*y[1] + A - ADCXQ BP, CX - MULXQ R11, AX, BP - ADOXQ AX, CX - - // (A,t[3]) := t[3] + x[3]*y[1] + A - ADCXQ BP, BX - MULXQ 24(DI), AX, BP - ADOXQ AX, BX - - // (A,t[4]) := t[4] + x[4]*y[1] + A - ADCXQ BP, SI - MULXQ 32(DI), AX, BP - ADOXQ AX, SI - - // A += carries from ADCXQ and ADOXQ - MOVQ $0, AX - ADCXQ AX, BP - ADOXQ AX, BP - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ R14, DX - - // clear the flags - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ q<>+0(SB), AX, R8 - ADCXQ R14, AX - MOVQ R8, R14 - - // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R13, R14 - MULXQ q<>+8(SB), AX, R13 - ADOXQ AX, R14 - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ CX, R13 - MULXQ 
q<>+16(SB), AX, CX - ADOXQ AX, R13 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ BX, CX - MULXQ q<>+24(SB), AX, BX - ADOXQ AX, CX - - // (C,t[3]) := t[4] + m*q[4] + C - ADCXQ SI, BX - MULXQ q<>+32(SB), AX, SI - ADOXQ AX, BX - - // t[4] = C + A - MOVQ $0, AX - ADCXQ AX, SI - ADOXQ BP, SI - - // clear the flags - XORQ AX, AX - MOVQ 16(R12), DX - - // (A,t[0]) := t[0] + x[0]*y[2] + A - MULXQ R9, AX, BP - ADOXQ AX, R14 - - // (A,t[1]) := t[1] + x[1]*y[2] + A - ADCXQ BP, R13 - MULXQ R10, AX, BP - ADOXQ AX, R13 - - // (A,t[2]) := t[2] + x[2]*y[2] + A - ADCXQ BP, CX - MULXQ R11, AX, BP - ADOXQ AX, CX - - // (A,t[3]) := t[3] + x[3]*y[2] + A - ADCXQ BP, BX - MULXQ 24(DI), AX, BP - ADOXQ AX, BX - - // (A,t[4]) := t[4] + x[4]*y[2] + A - ADCXQ BP, SI - MULXQ 32(DI), AX, BP - ADOXQ AX, SI - - // A += carries from ADCXQ and ADOXQ - MOVQ $0, AX - ADCXQ AX, BP - ADOXQ AX, BP - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ R14, DX - - // clear the flags - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ q<>+0(SB), AX, R8 - ADCXQ R14, AX - MOVQ R8, R14 - - // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R13, R14 - MULXQ q<>+8(SB), AX, R13 - ADOXQ AX, R14 - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ CX, R13 - MULXQ q<>+16(SB), AX, CX - ADOXQ AX, R13 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ BX, CX - MULXQ q<>+24(SB), AX, BX - ADOXQ AX, CX - - // (C,t[3]) := t[4] + m*q[4] + C - ADCXQ SI, BX - MULXQ q<>+32(SB), AX, SI - ADOXQ AX, BX - - // t[4] = C + A - MOVQ $0, AX - ADCXQ AX, SI - ADOXQ BP, SI - - // clear the flags - XORQ AX, AX - MOVQ 24(R12), DX - - // (A,t[0]) := t[0] + x[0]*y[3] + A - MULXQ R9, AX, BP - ADOXQ AX, R14 - - // (A,t[1]) := t[1] + x[1]*y[3] + A - ADCXQ BP, R13 - MULXQ R10, AX, BP - ADOXQ AX, R13 - - // (A,t[2]) := t[2] + x[2]*y[3] + A - ADCXQ BP, CX - MULXQ R11, AX, BP - ADOXQ AX, CX - - // (A,t[3]) := t[3] + x[3]*y[3] + A - ADCXQ BP, BX - MULXQ 24(DI), AX, BP - ADOXQ AX, BX - - // (A,t[4]) := t[4] + x[4]*y[3] + A - ADCXQ BP, SI - MULXQ 32(DI), AX, BP - ADOXQ AX, SI - - // A += carries from ADCXQ and ADOXQ - MOVQ $0, AX - ADCXQ AX, BP - ADOXQ AX, BP - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ R14, DX - - // clear the flags - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ q<>+0(SB), AX, R8 - ADCXQ R14, AX - MOVQ R8, R14 - - // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R13, R14 - MULXQ q<>+8(SB), AX, R13 - ADOXQ AX, R14 - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ CX, R13 - MULXQ q<>+16(SB), AX, CX - ADOXQ AX, R13 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ BX, CX - MULXQ q<>+24(SB), AX, BX - ADOXQ AX, CX - - // (C,t[3]) := t[4] + m*q[4] + C - ADCXQ SI, BX - MULXQ q<>+32(SB), AX, SI - ADOXQ AX, BX - - // t[4] = C + A - MOVQ $0, AX - ADCXQ AX, SI - ADOXQ BP, SI - - // clear the flags - XORQ AX, AX - MOVQ 32(R12), DX - - // (A,t[0]) := t[0] + x[0]*y[4] + A - MULXQ R9, AX, BP - ADOXQ AX, R14 - - // (A,t[1]) := t[1] + x[1]*y[4] + A - ADCXQ BP, R13 - MULXQ R10, AX, BP - ADOXQ AX, R13 - - // (A,t[2]) := t[2] + x[2]*y[4] + A - ADCXQ BP, CX - MULXQ R11, AX, BP - ADOXQ AX, CX - - // (A,t[3]) := t[3] + x[3]*y[4] + A - ADCXQ BP, BX - MULXQ 24(DI), AX, BP - ADOXQ AX, BX - - // (A,t[4]) := t[4] + x[4]*y[4] + A - ADCXQ BP, SI - MULXQ 32(DI), AX, BP - ADOXQ AX, SI - - // A += carries from ADCXQ and ADOXQ - MOVQ $0, AX - ADCXQ AX, BP - ADOXQ AX, BP - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ R14, DX - - // clear the flags - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ q<>+0(SB), AX, R8 - ADCXQ R14, AX - MOVQ R8, R14 - - // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R13, R14 - 
MULXQ q<>+8(SB), AX, R13 - ADOXQ AX, R14 - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ CX, R13 - MULXQ q<>+16(SB), AX, CX - ADOXQ AX, R13 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ BX, CX - MULXQ q<>+24(SB), AX, BX - ADOXQ AX, CX - - // (C,t[3]) := t[4] + m*q[4] + C - ADCXQ SI, BX - MULXQ q<>+32(SB), AX, SI - ADOXQ AX, BX - - // t[4] = C + A - MOVQ $0, AX - ADCXQ AX, SI - ADOXQ BP, SI - - // reduce element(R14,R13,CX,BX,SI) using temp registers (R8,DI,R12,R9,R10) - REDUCE(R14,R13,CX,BX,SI,R8,DI,R12,R9,R10) - - MOVQ res+0(FP), AX - MOVQ R14, 0(AX) - MOVQ R13, 8(AX) - MOVQ CX, 16(AX) - MOVQ BX, 24(AX) - MOVQ SI, 32(AX) - RET - -noAdx_1: - MOVQ res+0(FP), AX - MOVQ AX, (SP) - MOVQ x+8(FP), AX - MOVQ AX, 8(SP) - MOVQ y+16(FP), AX - MOVQ AX, 16(SP) - CALL ·_mulGeneric(SB) - RET - -TEXT ·fromMont(SB), $8-8 - NO_LOCAL_POINTERS - - // the algorithm is described here - // https://hackmd.io/@gnark/modular_multiplication - // when y = 1 we have: - // for i=0 to N-1 - // t[i] = x[i] - // for i=0 to N-1 - // m := t[0]*q'[0] mod W - // C,_ := t[0] + m*q[0] - // for j=1 to N-1 - // (C,t[j-1]) := t[j] + m*q[j] + C - // t[N-1] = C - CMPB ·supportAdx(SB), $1 - JNE noAdx_2 - MOVQ res+0(FP), DX - MOVQ 0(DX), R14 - MOVQ 8(DX), R13 - MOVQ 16(DX), CX - MOVQ 24(DX), BX - MOVQ 32(DX), SI - XORQ DX, DX - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ R14, DX - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ q<>+0(SB), AX, BP - ADCXQ R14, AX - MOVQ BP, R14 - - // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R13, R14 - MULXQ q<>+8(SB), AX, R13 - ADOXQ AX, R14 - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ CX, R13 - MULXQ q<>+16(SB), AX, CX - ADOXQ AX, R13 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ BX, CX - MULXQ q<>+24(SB), AX, BX - ADOXQ AX, CX - - // (C,t[3]) := t[4] + m*q[4] + C - ADCXQ SI, BX - MULXQ q<>+32(SB), AX, SI - ADOXQ AX, BX - MOVQ $0, AX - ADCXQ AX, SI - ADOXQ AX, SI - XORQ DX, DX - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ R14, DX - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ q<>+0(SB), AX, BP - ADCXQ R14, AX - MOVQ BP, R14 - - // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R13, R14 - MULXQ q<>+8(SB), AX, R13 - ADOXQ AX, R14 - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ CX, R13 - MULXQ q<>+16(SB), AX, CX - ADOXQ AX, R13 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ BX, CX - MULXQ q<>+24(SB), AX, BX - ADOXQ AX, CX - - // (C,t[3]) := t[4] + m*q[4] + C - ADCXQ SI, BX - MULXQ q<>+32(SB), AX, SI - ADOXQ AX, BX - MOVQ $0, AX - ADCXQ AX, SI - ADOXQ AX, SI - XORQ DX, DX - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ R14, DX - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ q<>+0(SB), AX, BP - ADCXQ R14, AX - MOVQ BP, R14 - - // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R13, R14 - MULXQ q<>+8(SB), AX, R13 - ADOXQ AX, R14 - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ CX, R13 - MULXQ q<>+16(SB), AX, CX - ADOXQ AX, R13 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ BX, CX - MULXQ q<>+24(SB), AX, BX - ADOXQ AX, CX - - // (C,t[3]) := t[4] + m*q[4] + C - ADCXQ SI, BX - MULXQ q<>+32(SB), AX, SI - ADOXQ AX, BX - MOVQ $0, AX - ADCXQ AX, SI - ADOXQ AX, SI - XORQ DX, DX - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ R14, DX - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ q<>+0(SB), AX, BP - ADCXQ R14, AX - MOVQ BP, R14 - - // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R13, R14 - MULXQ q<>+8(SB), AX, R13 - ADOXQ AX, R14 - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ CX, R13 - MULXQ q<>+16(SB), AX, CX - ADOXQ AX, R13 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ BX, CX - MULXQ q<>+24(SB), AX, 
BX - ADOXQ AX, CX - - // (C,t[3]) := t[4] + m*q[4] + C - ADCXQ SI, BX - MULXQ q<>+32(SB), AX, SI - ADOXQ AX, BX - MOVQ $0, AX - ADCXQ AX, SI - ADOXQ AX, SI - XORQ DX, DX - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ R14, DX - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ q<>+0(SB), AX, BP - ADCXQ R14, AX - MOVQ BP, R14 - - // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R13, R14 - MULXQ q<>+8(SB), AX, R13 - ADOXQ AX, R14 - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ CX, R13 - MULXQ q<>+16(SB), AX, CX - ADOXQ AX, R13 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ BX, CX - MULXQ q<>+24(SB), AX, BX - ADOXQ AX, CX - - // (C,t[3]) := t[4] + m*q[4] + C - ADCXQ SI, BX - MULXQ q<>+32(SB), AX, SI - ADOXQ AX, BX - MOVQ $0, AX - ADCXQ AX, SI - ADOXQ AX, SI - - // reduce element(R14,R13,CX,BX,SI) using temp registers (DI,R8,R9,R10,R11) - REDUCE(R14,R13,CX,BX,SI,DI,R8,R9,R10,R11) - - MOVQ res+0(FP), AX - MOVQ R14, 0(AX) - MOVQ R13, 8(AX) - MOVQ CX, 16(AX) - MOVQ BX, 24(AX) - MOVQ SI, 32(AX) - RET - -noAdx_2: - MOVQ res+0(FP), AX - MOVQ AX, (SP) - CALL ·_fromMontGeneric(SB) - RET diff --git a/ecc/bls24-317/fp/element_ops_amd64.go b/ecc/bls24-317/fp/element_ops_amd64.go index 83bba45ae..ed2803d71 100644 --- a/ecc/bls24-317/fp/element_ops_amd64.go +++ b/ecc/bls24-317/fp/element_ops_amd64.go @@ -50,48 +50,8 @@ func Butterfly(a, b *Element) // x and y must be less than q func (z *Element) Mul(x, y *Element) *Element { - // Implements CIOS multiplication -- section 2.3.2 of Tolga Acar's thesis - // https://www.microsoft.com/en-us/research/wp-content/uploads/1998/06/97Acar.pdf - // - // The algorithm: - // - // for i=0 to N-1 - // C := 0 - // for j=0 to N-1 - // (C,t[j]) := t[j] + x[j]*y[i] + C - // (t[N+1],t[N]) := t[N] + C - // - // C := 0 - // m := t[0]*q'[0] mod D - // (C,_) := t[0] + m*q[0] - // for j=1 to N-1 - // (C,t[j-1]) := t[j] + m*q[j] + C - // - // (C,t[N-1]) := t[N] + C - // t[N] := t[N+1] + C - // - // → N is the number of machine words needed to store the modulus q - // → D is the word size. For example, on a 64-bit architecture D is 2 64 - // → x[i], y[i], q[i] is the ith word of the numbers x,y,q - // → q'[0] is the lowest word of the number -q⁻¹ mod r. This quantity is pre-computed, as it does not depend on the inputs. - // → t is a temporary array of size N+2 - // → C, S are machine words. A pair (C,S) refers to (hi-bits, lo-bits) of a two-word number - // - // As described here https://hackmd.io/@gnark/modular_multiplication we can get rid of one carry chain and simplify: - // (also described in https://eprint.iacr.org/2022/1400.pdf annex) - // - // for i=0 to N-1 - // (A,t[0]) := t[0] + x[0]*y[i] - // m := t[0]*q'[0] mod W - // C,_ := t[0] + m*q[0] - // for j=1 to N-1 - // (A,t[j]) := t[j] + x[j]*y[i] + A - // (C,t[j-1]) := t[j] + m*q[j] + C - // - // t[N-1] = C + A - // - // This optimization saves 5N + 2 additions in the algorithm, and can be used whenever the highest bit - // of the modulus is zero (and not all of the remaining bits are set). + // Algorithm 2 of "Faster Montgomery Multiplication and Multi-Scalar-Multiplication for SNARKS" + // by Y. El Housni and G. Botrel https://doi.org/10.46586/tches.v2023.i3.504-521 mul(z, x, y) return z diff --git a/ecc/bls24-317/fp/element_ops_amd64.s b/ecc/bls24-317/fp/element_ops_amd64.s index cb68645b3..9da947df7 100644 --- a/ecc/bls24-317/fp/element_ops_amd64.s +++ b/ecc/bls24-317/fp/element_ops_amd64.s @@ -1,21 +1,12 @@ // +build !purego -// Copyright 2020 ConsenSys Software Inc. 
-// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. +#define q0 $0x8d512e565dab2aab +#define q1 $0xd6f339e43424bf7e +#define q2 $0x169a61e684c73446 +#define q3 $0xf28fc5a0b7f9d039 +#define q4 $0x1058ca226f60892c -#include "textflag.h" -#include "funcdata.h" +#include "../../../field/asm/element_5w_amd64.h" // modulus q DATA q<>+0(SB)/8, $0x8d512e565dab2aab @@ -29,244 +20,3 @@ GLOBL q<>(SB), (RODATA+NOPTR), $40 DATA qInv0<>(SB)/8, $0x55b5e0028b047ffd GLOBL qInv0<>(SB), (RODATA+NOPTR), $8 -#define REDUCE(ra0, ra1, ra2, ra3, ra4, rb0, rb1, rb2, rb3, rb4) \ - MOVQ ra0, rb0; \ - SUBQ q<>(SB), ra0; \ - MOVQ ra1, rb1; \ - SBBQ q<>+8(SB), ra1; \ - MOVQ ra2, rb2; \ - SBBQ q<>+16(SB), ra2; \ - MOVQ ra3, rb3; \ - SBBQ q<>+24(SB), ra3; \ - MOVQ ra4, rb4; \ - SBBQ q<>+32(SB), ra4; \ - CMOVQCS rb0, ra0; \ - CMOVQCS rb1, ra1; \ - CMOVQCS rb2, ra2; \ - CMOVQCS rb3, ra3; \ - CMOVQCS rb4, ra4; \ - -TEXT ·reduce(SB), NOSPLIT, $0-8 - MOVQ res+0(FP), AX - MOVQ 0(AX), DX - MOVQ 8(AX), CX - MOVQ 16(AX), BX - MOVQ 24(AX), SI - MOVQ 32(AX), DI - - // reduce element(DX,CX,BX,SI,DI) using temp registers (R8,R9,R10,R11,R12) - REDUCE(DX,CX,BX,SI,DI,R8,R9,R10,R11,R12) - - MOVQ DX, 0(AX) - MOVQ CX, 8(AX) - MOVQ BX, 16(AX) - MOVQ SI, 24(AX) - MOVQ DI, 32(AX) - RET - -// MulBy3(x *Element) -TEXT ·MulBy3(SB), NOSPLIT, $0-8 - MOVQ x+0(FP), AX - MOVQ 0(AX), DX - MOVQ 8(AX), CX - MOVQ 16(AX), BX - MOVQ 24(AX), SI - MOVQ 32(AX), DI - ADDQ DX, DX - ADCQ CX, CX - ADCQ BX, BX - ADCQ SI, SI - ADCQ DI, DI - - // reduce element(DX,CX,BX,SI,DI) using temp registers (R8,R9,R10,R11,R12) - REDUCE(DX,CX,BX,SI,DI,R8,R9,R10,R11,R12) - - ADDQ 0(AX), DX - ADCQ 8(AX), CX - ADCQ 16(AX), BX - ADCQ 24(AX), SI - ADCQ 32(AX), DI - - // reduce element(DX,CX,BX,SI,DI) using temp registers (R13,R14,R15,R8,R9) - REDUCE(DX,CX,BX,SI,DI,R13,R14,R15,R8,R9) - - MOVQ DX, 0(AX) - MOVQ CX, 8(AX) - MOVQ BX, 16(AX) - MOVQ SI, 24(AX) - MOVQ DI, 32(AX) - RET - -// MulBy5(x *Element) -TEXT ·MulBy5(SB), NOSPLIT, $0-8 - MOVQ x+0(FP), AX - MOVQ 0(AX), DX - MOVQ 8(AX), CX - MOVQ 16(AX), BX - MOVQ 24(AX), SI - MOVQ 32(AX), DI - ADDQ DX, DX - ADCQ CX, CX - ADCQ BX, BX - ADCQ SI, SI - ADCQ DI, DI - - // reduce element(DX,CX,BX,SI,DI) using temp registers (R8,R9,R10,R11,R12) - REDUCE(DX,CX,BX,SI,DI,R8,R9,R10,R11,R12) - - ADDQ DX, DX - ADCQ CX, CX - ADCQ BX, BX - ADCQ SI, SI - ADCQ DI, DI - - // reduce element(DX,CX,BX,SI,DI) using temp registers (R13,R14,R15,R8,R9) - REDUCE(DX,CX,BX,SI,DI,R13,R14,R15,R8,R9) - - ADDQ 0(AX), DX - ADCQ 8(AX), CX - ADCQ 16(AX), BX - ADCQ 24(AX), SI - ADCQ 32(AX), DI - - // reduce element(DX,CX,BX,SI,DI) using temp registers (R10,R11,R12,R13,R14) - REDUCE(DX,CX,BX,SI,DI,R10,R11,R12,R13,R14) - - MOVQ DX, 0(AX) - MOVQ CX, 8(AX) - MOVQ BX, 16(AX) - MOVQ SI, 24(AX) - MOVQ DI, 32(AX) - RET - -// MulBy13(x *Element) -TEXT ·MulBy13(SB), $16-8 - MOVQ x+0(FP), AX - MOVQ 0(AX), DX - MOVQ 8(AX), CX - MOVQ 16(AX), BX - MOVQ 24(AX), SI - MOVQ 32(AX), DI - ADDQ DX, DX - ADCQ CX, CX - ADCQ BX, BX - ADCQ SI, SI - ADCQ DI, DI - - // reduce 
element(DX,CX,BX,SI,DI) using temp registers (R8,R9,R10,R11,R12) - REDUCE(DX,CX,BX,SI,DI,R8,R9,R10,R11,R12) - - ADDQ DX, DX - ADCQ CX, CX - ADCQ BX, BX - ADCQ SI, SI - ADCQ DI, DI - - // reduce element(DX,CX,BX,SI,DI) using temp registers (R13,R14,R15,s0-8(SP),s1-16(SP)) - REDUCE(DX,CX,BX,SI,DI,R13,R14,R15,s0-8(SP),s1-16(SP)) - - MOVQ DX, R13 - MOVQ CX, R14 - MOVQ BX, R15 - MOVQ SI, s0-8(SP) - MOVQ DI, s1-16(SP) - ADDQ DX, DX - ADCQ CX, CX - ADCQ BX, BX - ADCQ SI, SI - ADCQ DI, DI - - // reduce element(DX,CX,BX,SI,DI) using temp registers (R8,R9,R10,R11,R12) - REDUCE(DX,CX,BX,SI,DI,R8,R9,R10,R11,R12) - - ADDQ R13, DX - ADCQ R14, CX - ADCQ R15, BX - ADCQ s0-8(SP), SI - ADCQ s1-16(SP), DI - - // reduce element(DX,CX,BX,SI,DI) using temp registers (R8,R9,R10,R11,R12) - REDUCE(DX,CX,BX,SI,DI,R8,R9,R10,R11,R12) - - ADDQ 0(AX), DX - ADCQ 8(AX), CX - ADCQ 16(AX), BX - ADCQ 24(AX), SI - ADCQ 32(AX), DI - - // reduce element(DX,CX,BX,SI,DI) using temp registers (R8,R9,R10,R11,R12) - REDUCE(DX,CX,BX,SI,DI,R8,R9,R10,R11,R12) - - MOVQ DX, 0(AX) - MOVQ CX, 8(AX) - MOVQ BX, 16(AX) - MOVQ SI, 24(AX) - MOVQ DI, 32(AX) - RET - -// Butterfly(a, b *Element) sets a = a + b; b = a - b -TEXT ·Butterfly(SB), $24-16 - MOVQ a+0(FP), AX - MOVQ 0(AX), CX - MOVQ 8(AX), BX - MOVQ 16(AX), SI - MOVQ 24(AX), DI - MOVQ 32(AX), R8 - MOVQ CX, R9 - MOVQ BX, R10 - MOVQ SI, R11 - MOVQ DI, R12 - MOVQ R8, R13 - XORQ AX, AX - MOVQ b+8(FP), DX - ADDQ 0(DX), CX - ADCQ 8(DX), BX - ADCQ 16(DX), SI - ADCQ 24(DX), DI - ADCQ 32(DX), R8 - SUBQ 0(DX), R9 - SBBQ 8(DX), R10 - SBBQ 16(DX), R11 - SBBQ 24(DX), R12 - SBBQ 32(DX), R13 - MOVQ CX, R14 - MOVQ BX, R15 - MOVQ SI, s0-8(SP) - MOVQ DI, s1-16(SP) - MOVQ R8, s2-24(SP) - MOVQ $0x8d512e565dab2aab, CX - MOVQ $0xd6f339e43424bf7e, BX - MOVQ $0x169a61e684c73446, SI - MOVQ $0xf28fc5a0b7f9d039, DI - MOVQ $0x1058ca226f60892c, R8 - CMOVQCC AX, CX - CMOVQCC AX, BX - CMOVQCC AX, SI - CMOVQCC AX, DI - CMOVQCC AX, R8 - ADDQ CX, R9 - ADCQ BX, R10 - ADCQ SI, R11 - ADCQ DI, R12 - ADCQ R8, R13 - MOVQ R14, CX - MOVQ R15, BX - MOVQ s0-8(SP), SI - MOVQ s1-16(SP), DI - MOVQ s2-24(SP), R8 - MOVQ R9, 0(DX) - MOVQ R10, 8(DX) - MOVQ R11, 16(DX) - MOVQ R12, 24(DX) - MOVQ R13, 32(DX) - - // reduce element(CX,BX,SI,DI,R8) using temp registers (R9,R10,R11,R12,R13) - REDUCE(CX,BX,SI,DI,R8,R9,R10,R11,R12,R13) - - MOVQ a+0(FP), AX - MOVQ CX, 0(AX) - MOVQ BX, 8(AX) - MOVQ SI, 16(AX) - MOVQ DI, 24(AX) - MOVQ R8, 32(AX) - RET diff --git a/ecc/bls24-317/fp/element_ops_purego.go b/ecc/bls24-317/fp/element_ops_purego.go index aed04e01f..9f72e6f84 100644 --- a/ecc/bls24-317/fp/element_ops_purego.go +++ b/ecc/bls24-317/fp/element_ops_purego.go @@ -66,48 +66,8 @@ func reduce(z *Element) { // x and y must be less than q func (z *Element) Mul(x, y *Element) *Element { - // Implements CIOS multiplication -- section 2.3.2 of Tolga Acar's thesis - // https://www.microsoft.com/en-us/research/wp-content/uploads/1998/06/97Acar.pdf - // - // The algorithm: - // - // for i=0 to N-1 - // C := 0 - // for j=0 to N-1 - // (C,t[j]) := t[j] + x[j]*y[i] + C - // (t[N+1],t[N]) := t[N] + C - // - // C := 0 - // m := t[0]*q'[0] mod D - // (C,_) := t[0] + m*q[0] - // for j=1 to N-1 - // (C,t[j-1]) := t[j] + m*q[j] + C - // - // (C,t[N-1]) := t[N] + C - // t[N] := t[N+1] + C - // - // → N is the number of machine words needed to store the modulus q - // → D is the word size. For example, on a 64-bit architecture D is 2 64 - // → x[i], y[i], q[i] is the ith word of the numbers x,y,q - // → q'[0] is the lowest word of the number -q⁻¹ mod r. 
This quantity is pre-computed, as it does not depend on the inputs. - // → t is a temporary array of size N+2 - // → C, S are machine words. A pair (C,S) refers to (hi-bits, lo-bits) of a two-word number - // - // As described here https://hackmd.io/@gnark/modular_multiplication we can get rid of one carry chain and simplify: - // (also described in https://eprint.iacr.org/2022/1400.pdf annex) - // - // for i=0 to N-1 - // (A,t[0]) := t[0] + x[0]*y[i] - // m := t[0]*q'[0] mod W - // C,_ := t[0] + m*q[0] - // for j=1 to N-1 - // (A,t[j]) := t[j] + x[j]*y[i] + A - // (C,t[j-1]) := t[j] + m*q[j] + C - // - // t[N-1] = C + A - // - // This optimization saves 5N + 2 additions in the algorithm, and can be used whenever the highest bit - // of the modulus is zero (and not all of the remaining bits are set). + // Algorithm 2 of "Faster Montgomery Multiplication and Multi-Scalar-Multiplication for SNARKS" + // by Y. El Housni and G. Botrel https://doi.org/10.46586/tches.v2023.i3.504-521 var t0, t1, t2, t3, t4 uint64 var u0, u1, u2, u3, u4 uint64 diff --git a/ecc/bls24-317/fp/element_test.go b/ecc/bls24-317/fp/element_test.go index da5ab0e0b..c2e057c7d 100644 --- a/ecc/bls24-317/fp/element_test.go +++ b/ecc/bls24-317/fp/element_test.go @@ -639,7 +639,6 @@ func TestElementBitLen(t *testing.T) { )) properties.TestingRun(t, gopter.ConsoleReporter(false)) - } func TestElementButterflies(t *testing.T) { @@ -721,11 +720,17 @@ func TestElementVecOps(t *testing.T) { // set m to max values element // it's not really q-1 (since we have montgomery representation) // but it's the "largest" legal value - qMinus1 := new(big.Int).Sub(Modulus(), big.NewInt(1)) - - var eQMinus1 Element - for i, v := range qMinus1.Bits() { - eQMinus1[i] = uint64(v) + eQMinus1 := qElement + if eQMinus1[0] != 0 { + eQMinus1[0]-- + } else { + eQMinus1[0] = ^uint64(0) + for i := 1; i < len(eQMinus1); i++ { + if eQMinus1[i] != 0 { + eQMinus1[i]-- + break + } + } } for i := 0; i < N; i++ { diff --git a/ecc/bls24-317/fr/element.go b/ecc/bls24-317/fr/element.go index bf3215dad..0f05ae910 100644 --- a/ecc/bls24-317/fr/element.go +++ b/ecc/bls24-317/fr/element.go @@ -477,32 +477,8 @@ func (z *Element) Select(c int, x0 *Element, x1 *Element) *Element { // and is used for testing purposes. func _mulGeneric(z, x, y *Element) { - // Implements CIOS multiplication -- section 2.3.2 of Tolga Acar's thesis - // https://www.microsoft.com/en-us/research/wp-content/uploads/1998/06/97Acar.pdf - // - // The algorithm: - // - // for i=0 to N-1 - // C := 0 - // for j=0 to N-1 - // (C,t[j]) := t[j] + x[j]*y[i] + C - // (t[N+1],t[N]) := t[N] + C - // - // C := 0 - // m := t[0]*q'[0] mod D - // (C,_) := t[0] + m*q[0] - // for j=1 to N-1 - // (C,t[j-1]) := t[j] + m*q[j] + C - // - // (C,t[N-1]) := t[N] + C - // t[N] := t[N+1] + C - // - // → N is the number of machine words needed to store the modulus q - // → D is the word size. For example, on a 64-bit architecture D is 2 64 - // → x[i], y[i], q[i] is the ith word of the numbers x,y,q - // → q'[0] is the lowest word of the number -q⁻¹ mod r. This quantity is pre-computed, as it does not depend on the inputs. - // → t is a temporary array of size N+2 - // → C, S are machine words. A pair (C,S) refers to (hi-bits, lo-bits) of a two-word number + // Algorithm 2 of "Faster Montgomery Multiplication and Multi-Scalar-Multiplication for SNARKS" + // by Y. El Housni and G. 
Botrel https://doi.org/10.46586/tches.v2023.i3.504-521 var t [5]uint64 var D uint64 diff --git a/ecc/bls24-317/fr/element_mul_amd64.s b/ecc/bls24-317/fr/element_mul_amd64.s deleted file mode 100644 index 77d9a3fc4..000000000 --- a/ecc/bls24-317/fr/element_mul_amd64.s +++ /dev/null @@ -1,490 +0,0 @@ -// +build !purego - -// Copyright 2020 ConsenSys Software Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "textflag.h" -#include "funcdata.h" - -// modulus q -DATA q<>+0(SB)/8, $0xf000000000000001 -DATA q<>+8(SB)/8, $0x1cd1e79196bf0e7a -DATA q<>+16(SB)/8, $0xd0b097f28d83cd49 -DATA q<>+24(SB)/8, $0x443f917ea68dafc2 -GLOBL q<>(SB), (RODATA+NOPTR), $32 - -// qInv0 q'[0] -DATA qInv0<>(SB)/8, $0xefffffffffffffff -GLOBL qInv0<>(SB), (RODATA+NOPTR), $8 -// Mu -DATA mu<>(SB)/8, $0x00000003c0421687 -GLOBL mu<>(SB), (RODATA+NOPTR), $8 - -#define REDUCE(ra0, ra1, ra2, ra3, rb0, rb1, rb2, rb3) \ - MOVQ ra0, rb0; \ - SUBQ q<>(SB), ra0; \ - MOVQ ra1, rb1; \ - SBBQ q<>+8(SB), ra1; \ - MOVQ ra2, rb2; \ - SBBQ q<>+16(SB), ra2; \ - MOVQ ra3, rb3; \ - SBBQ q<>+24(SB), ra3; \ - CMOVQCS rb0, ra0; \ - CMOVQCS rb1, ra1; \ - CMOVQCS rb2, ra2; \ - CMOVQCS rb3, ra3; \ - -// mul(res, x, y *Element) -TEXT ·mul(SB), $24-24 - - // the algorithm is described in the Element.Mul declaration (.go) - // however, to benefit from the ADCX and ADOX carry chains - // we split the inner loops in 2: - // for i=0 to N-1 - // for j=0 to N-1 - // (A,t[j]) := t[j] + x[j]*y[i] + A - // m := t[0]*q'[0] mod W - // C,_ := t[0] + m*q[0] - // for j=1 to N-1 - // (C,t[j-1]) := t[j] + m*q[j] + C - // t[N-1] = C + A - - NO_LOCAL_POINTERS - CMPB ·supportAdx(SB), $1 - JNE noAdx_1 - MOVQ x+8(FP), SI - - // x[0] -> DI - // x[1] -> R8 - // x[2] -> R9 - // x[3] -> R10 - MOVQ 0(SI), DI - MOVQ 8(SI), R8 - MOVQ 16(SI), R9 - MOVQ 24(SI), R10 - MOVQ y+16(FP), R11 - - // A -> BP - // t[0] -> R14 - // t[1] -> R13 - // t[2] -> CX - // t[3] -> BX - // clear the flags - XORQ AX, AX - MOVQ 0(R11), DX - - // (A,t[0]) := x[0]*y[0] + A - MULXQ DI, R14, R13 - - // (A,t[1]) := x[1]*y[0] + A - MULXQ R8, AX, CX - ADOXQ AX, R13 - - // (A,t[2]) := x[2]*y[0] + A - MULXQ R9, AX, BX - ADOXQ AX, CX - - // (A,t[3]) := x[3]*y[0] + A - MULXQ R10, AX, BP - ADOXQ AX, BX - - // A += carries from ADCXQ and ADOXQ - MOVQ $0, AX - ADOXQ AX, BP - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ R14, DX - - // clear the flags - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ q<>+0(SB), AX, R12 - ADCXQ R14, AX - MOVQ R12, R14 - - // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R13, R14 - MULXQ q<>+8(SB), AX, R13 - ADOXQ AX, R14 - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ CX, R13 - MULXQ q<>+16(SB), AX, CX - ADOXQ AX, R13 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ BX, CX - MULXQ q<>+24(SB), AX, BX - ADOXQ AX, CX - - // t[3] = C + A - MOVQ $0, AX - ADCXQ AX, BX - ADOXQ BP, BX - - // clear the flags - XORQ AX, AX - MOVQ 8(R11), DX - - // (A,t[0]) := t[0] + x[0]*y[1] + A - MULXQ DI, AX, BP - ADOXQ AX, R14 - - // (A,t[1]) := t[1] + x[1]*y[1] + A - ADCXQ BP, R13 - 
MULXQ R8, AX, BP - ADOXQ AX, R13 - - // (A,t[2]) := t[2] + x[2]*y[1] + A - ADCXQ BP, CX - MULXQ R9, AX, BP - ADOXQ AX, CX - - // (A,t[3]) := t[3] + x[3]*y[1] + A - ADCXQ BP, BX - MULXQ R10, AX, BP - ADOXQ AX, BX - - // A += carries from ADCXQ and ADOXQ - MOVQ $0, AX - ADCXQ AX, BP - ADOXQ AX, BP - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ R14, DX - - // clear the flags - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ q<>+0(SB), AX, R12 - ADCXQ R14, AX - MOVQ R12, R14 - - // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R13, R14 - MULXQ q<>+8(SB), AX, R13 - ADOXQ AX, R14 - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ CX, R13 - MULXQ q<>+16(SB), AX, CX - ADOXQ AX, R13 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ BX, CX - MULXQ q<>+24(SB), AX, BX - ADOXQ AX, CX - - // t[3] = C + A - MOVQ $0, AX - ADCXQ AX, BX - ADOXQ BP, BX - - // clear the flags - XORQ AX, AX - MOVQ 16(R11), DX - - // (A,t[0]) := t[0] + x[0]*y[2] + A - MULXQ DI, AX, BP - ADOXQ AX, R14 - - // (A,t[1]) := t[1] + x[1]*y[2] + A - ADCXQ BP, R13 - MULXQ R8, AX, BP - ADOXQ AX, R13 - - // (A,t[2]) := t[2] + x[2]*y[2] + A - ADCXQ BP, CX - MULXQ R9, AX, BP - ADOXQ AX, CX - - // (A,t[3]) := t[3] + x[3]*y[2] + A - ADCXQ BP, BX - MULXQ R10, AX, BP - ADOXQ AX, BX - - // A += carries from ADCXQ and ADOXQ - MOVQ $0, AX - ADCXQ AX, BP - ADOXQ AX, BP - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ R14, DX - - // clear the flags - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ q<>+0(SB), AX, R12 - ADCXQ R14, AX - MOVQ R12, R14 - - // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R13, R14 - MULXQ q<>+8(SB), AX, R13 - ADOXQ AX, R14 - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ CX, R13 - MULXQ q<>+16(SB), AX, CX - ADOXQ AX, R13 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ BX, CX - MULXQ q<>+24(SB), AX, BX - ADOXQ AX, CX - - // t[3] = C + A - MOVQ $0, AX - ADCXQ AX, BX - ADOXQ BP, BX - - // clear the flags - XORQ AX, AX - MOVQ 24(R11), DX - - // (A,t[0]) := t[0] + x[0]*y[3] + A - MULXQ DI, AX, BP - ADOXQ AX, R14 - - // (A,t[1]) := t[1] + x[1]*y[3] + A - ADCXQ BP, R13 - MULXQ R8, AX, BP - ADOXQ AX, R13 - - // (A,t[2]) := t[2] + x[2]*y[3] + A - ADCXQ BP, CX - MULXQ R9, AX, BP - ADOXQ AX, CX - - // (A,t[3]) := t[3] + x[3]*y[3] + A - ADCXQ BP, BX - MULXQ R10, AX, BP - ADOXQ AX, BX - - // A += carries from ADCXQ and ADOXQ - MOVQ $0, AX - ADCXQ AX, BP - ADOXQ AX, BP - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ R14, DX - - // clear the flags - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ q<>+0(SB), AX, R12 - ADCXQ R14, AX - MOVQ R12, R14 - - // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R13, R14 - MULXQ q<>+8(SB), AX, R13 - ADOXQ AX, R14 - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ CX, R13 - MULXQ q<>+16(SB), AX, CX - ADOXQ AX, R13 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ BX, CX - MULXQ q<>+24(SB), AX, BX - ADOXQ AX, CX - - // t[3] = C + A - MOVQ $0, AX - ADCXQ AX, BX - ADOXQ BP, BX - - // reduce element(R14,R13,CX,BX) using temp registers (SI,R12,R11,DI) - REDUCE(R14,R13,CX,BX,SI,R12,R11,DI) - - MOVQ res+0(FP), AX - MOVQ R14, 0(AX) - MOVQ R13, 8(AX) - MOVQ CX, 16(AX) - MOVQ BX, 24(AX) - RET - -noAdx_1: - MOVQ res+0(FP), AX - MOVQ AX, (SP) - MOVQ x+8(FP), AX - MOVQ AX, 8(SP) - MOVQ y+16(FP), AX - MOVQ AX, 16(SP) - CALL ·_mulGeneric(SB) - RET - -TEXT ·fromMont(SB), $8-8 - NO_LOCAL_POINTERS - - // the algorithm is described here - // https://hackmd.io/@gnark/modular_multiplication - // when y = 1 we have: - // for i=0 to N-1 - // t[i] = x[i] - // for i=0 to N-1 - // m := t[0]*q'[0] mod W - // C,_ := t[0] + m*q[0] - // for j=1 to 
N-1 - // (C,t[j-1]) := t[j] + m*q[j] + C - // t[N-1] = C - CMPB ·supportAdx(SB), $1 - JNE noAdx_2 - MOVQ res+0(FP), DX - MOVQ 0(DX), R14 - MOVQ 8(DX), R13 - MOVQ 16(DX), CX - MOVQ 24(DX), BX - XORQ DX, DX - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ R14, DX - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ q<>+0(SB), AX, BP - ADCXQ R14, AX - MOVQ BP, R14 - - // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R13, R14 - MULXQ q<>+8(SB), AX, R13 - ADOXQ AX, R14 - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ CX, R13 - MULXQ q<>+16(SB), AX, CX - ADOXQ AX, R13 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ BX, CX - MULXQ q<>+24(SB), AX, BX - ADOXQ AX, CX - MOVQ $0, AX - ADCXQ AX, BX - ADOXQ AX, BX - XORQ DX, DX - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ R14, DX - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ q<>+0(SB), AX, BP - ADCXQ R14, AX - MOVQ BP, R14 - - // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R13, R14 - MULXQ q<>+8(SB), AX, R13 - ADOXQ AX, R14 - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ CX, R13 - MULXQ q<>+16(SB), AX, CX - ADOXQ AX, R13 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ BX, CX - MULXQ q<>+24(SB), AX, BX - ADOXQ AX, CX - MOVQ $0, AX - ADCXQ AX, BX - ADOXQ AX, BX - XORQ DX, DX - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ R14, DX - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ q<>+0(SB), AX, BP - ADCXQ R14, AX - MOVQ BP, R14 - - // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R13, R14 - MULXQ q<>+8(SB), AX, R13 - ADOXQ AX, R14 - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ CX, R13 - MULXQ q<>+16(SB), AX, CX - ADOXQ AX, R13 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ BX, CX - MULXQ q<>+24(SB), AX, BX - ADOXQ AX, CX - MOVQ $0, AX - ADCXQ AX, BX - ADOXQ AX, BX - XORQ DX, DX - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ R14, DX - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ q<>+0(SB), AX, BP - ADCXQ R14, AX - MOVQ BP, R14 - - // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R13, R14 - MULXQ q<>+8(SB), AX, R13 - ADOXQ AX, R14 - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ CX, R13 - MULXQ q<>+16(SB), AX, CX - ADOXQ AX, R13 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ BX, CX - MULXQ q<>+24(SB), AX, BX - ADOXQ AX, CX - MOVQ $0, AX - ADCXQ AX, BX - ADOXQ AX, BX - - // reduce element(R14,R13,CX,BX) using temp registers (SI,DI,R8,R9) - REDUCE(R14,R13,CX,BX,SI,DI,R8,R9) - - MOVQ res+0(FP), AX - MOVQ R14, 0(AX) - MOVQ R13, 8(AX) - MOVQ CX, 16(AX) - MOVQ BX, 24(AX) - RET - -noAdx_2: - MOVQ res+0(FP), AX - MOVQ AX, (SP) - CALL ·_fromMontGeneric(SB) - RET diff --git a/ecc/bls24-317/fr/element_ops_amd64.go b/ecc/bls24-317/fr/element_ops_amd64.go index 49da23450..d21dabdaa 100644 --- a/ecc/bls24-317/fr/element_ops_amd64.go +++ b/ecc/bls24-317/fr/element_ops_amd64.go @@ -103,48 +103,8 @@ func sumVec(res *Element, a *Element, n uint64) // x and y must be less than q func (z *Element) Mul(x, y *Element) *Element { - // Implements CIOS multiplication -- section 2.3.2 of Tolga Acar's thesis - // https://www.microsoft.com/en-us/research/wp-content/uploads/1998/06/97Acar.pdf - // - // The algorithm: - // - // for i=0 to N-1 - // C := 0 - // for j=0 to N-1 - // (C,t[j]) := t[j] + x[j]*y[i] + C - // (t[N+1],t[N]) := t[N] + C - // - // C := 0 - // m := t[0]*q'[0] mod D - // (C,_) := t[0] + m*q[0] - // for j=1 to N-1 - // (C,t[j-1]) := t[j] + m*q[j] + C - // - // (C,t[N-1]) := t[N] + C - // t[N] := t[N+1] + C - // - // → N is the number of machine words needed to store the modulus q - // → D is the word size. 
For example, on a 64-bit architecture D is 2 64 - // → x[i], y[i], q[i] is the ith word of the numbers x,y,q - // → q'[0] is the lowest word of the number -q⁻¹ mod r. This quantity is pre-computed, as it does not depend on the inputs. - // → t is a temporary array of size N+2 - // → C, S are machine words. A pair (C,S) refers to (hi-bits, lo-bits) of a two-word number - // - // As described here https://hackmd.io/@gnark/modular_multiplication we can get rid of one carry chain and simplify: - // (also described in https://eprint.iacr.org/2022/1400.pdf annex) - // - // for i=0 to N-1 - // (A,t[0]) := t[0] + x[0]*y[i] - // m := t[0]*q'[0] mod W - // C,_ := t[0] + m*q[0] - // for j=1 to N-1 - // (A,t[j]) := t[j] + x[j]*y[i] + A - // (C,t[j-1]) := t[j] + m*q[j] + C - // - // t[N-1] = C + A - // - // This optimization saves 5N + 2 additions in the algorithm, and can be used whenever the highest bit - // of the modulus is zero (and not all of the remaining bits are set). + // Algorithm 2 of "Faster Montgomery Multiplication and Multi-Scalar-Multiplication for SNARKS" + // by Y. El Housni and G. Botrel https://doi.org/10.46586/tches.v2023.i3.504-521 mul(z, x, y) return z diff --git a/ecc/bls24-317/fr/element_ops_amd64.s b/ecc/bls24-317/fr/element_ops_amd64.s index 16c632592..dc747fafa 100644 --- a/ecc/bls24-317/fr/element_ops_amd64.s +++ b/ecc/bls24-317/fr/element_ops_amd64.s @@ -1,21 +1,11 @@ // +build !purego -// Copyright 2020 ConsenSys Software Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
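
The mu<> constant declared a few lines below is the single-word Barrett constant discussed earlier, i.e. ⌊2²⁸⁸ / r⌋ for this curve's scalar field. A quick, standalone way to recompute such a constant from the DATA q<> words (the hex string below is simply the four q words read most-significant first; compare the printed value against mu<>):

    package main

    import (
        "fmt"
        "math/big"
    )

    func main() {
        // r assembled from the DATA q<> words below, most-significant word first
        r, _ := new(big.Int).SetString(
            "443f917ea68dafc2d0b097f28d83cd491cd1e79196bf0e7af000000000000001", 16)

        // mu = floor(2^288 / r), the single-word Barrett constant
        mu := new(big.Int).Div(new(big.Int).Lsh(big.NewInt(1), 288), r)
        fmt.Printf("%#x\n", mu) // compare against the mu<> constant below
    }
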
+#define q0 $0xf000000000000001 +#define q1 $0x1cd1e79196bf0e7a +#define q2 $0xd0b097f28d83cd49 +#define q3 $0x443f917ea68dafc2 -#include "textflag.h" -#include "funcdata.h" +#include "../../../field/asm/element_4w_amd64.h" // modulus q DATA q<>+0(SB)/8, $0xf000000000000001 @@ -31,809 +21,3 @@ GLOBL qInv0<>(SB), (RODATA+NOPTR), $8 DATA mu<>(SB)/8, $0x00000003c0421687 GLOBL mu<>(SB), (RODATA+NOPTR), $8 -#define REDUCE(ra0, ra1, ra2, ra3, rb0, rb1, rb2, rb3) \ - MOVQ ra0, rb0; \ - SUBQ q<>(SB), ra0; \ - MOVQ ra1, rb1; \ - SBBQ q<>+8(SB), ra1; \ - MOVQ ra2, rb2; \ - SBBQ q<>+16(SB), ra2; \ - MOVQ ra3, rb3; \ - SBBQ q<>+24(SB), ra3; \ - CMOVQCS rb0, ra0; \ - CMOVQCS rb1, ra1; \ - CMOVQCS rb2, ra2; \ - CMOVQCS rb3, ra3; \ - -TEXT ·reduce(SB), NOSPLIT, $0-8 - MOVQ res+0(FP), AX - MOVQ 0(AX), DX - MOVQ 8(AX), CX - MOVQ 16(AX), BX - MOVQ 24(AX), SI - - // reduce element(DX,CX,BX,SI) using temp registers (DI,R8,R9,R10) - REDUCE(DX,CX,BX,SI,DI,R8,R9,R10) - - MOVQ DX, 0(AX) - MOVQ CX, 8(AX) - MOVQ BX, 16(AX) - MOVQ SI, 24(AX) - RET - -// MulBy3(x *Element) -TEXT ·MulBy3(SB), NOSPLIT, $0-8 - MOVQ x+0(FP), AX - MOVQ 0(AX), DX - MOVQ 8(AX), CX - MOVQ 16(AX), BX - MOVQ 24(AX), SI - ADDQ DX, DX - ADCQ CX, CX - ADCQ BX, BX - ADCQ SI, SI - - // reduce element(DX,CX,BX,SI) using temp registers (DI,R8,R9,R10) - REDUCE(DX,CX,BX,SI,DI,R8,R9,R10) - - ADDQ 0(AX), DX - ADCQ 8(AX), CX - ADCQ 16(AX), BX - ADCQ 24(AX), SI - - // reduce element(DX,CX,BX,SI) using temp registers (R11,R12,R13,R14) - REDUCE(DX,CX,BX,SI,R11,R12,R13,R14) - - MOVQ DX, 0(AX) - MOVQ CX, 8(AX) - MOVQ BX, 16(AX) - MOVQ SI, 24(AX) - RET - -// MulBy5(x *Element) -TEXT ·MulBy5(SB), NOSPLIT, $0-8 - MOVQ x+0(FP), AX - MOVQ 0(AX), DX - MOVQ 8(AX), CX - MOVQ 16(AX), BX - MOVQ 24(AX), SI - ADDQ DX, DX - ADCQ CX, CX - ADCQ BX, BX - ADCQ SI, SI - - // reduce element(DX,CX,BX,SI) using temp registers (DI,R8,R9,R10) - REDUCE(DX,CX,BX,SI,DI,R8,R9,R10) - - ADDQ DX, DX - ADCQ CX, CX - ADCQ BX, BX - ADCQ SI, SI - - // reduce element(DX,CX,BX,SI) using temp registers (R11,R12,R13,R14) - REDUCE(DX,CX,BX,SI,R11,R12,R13,R14) - - ADDQ 0(AX), DX - ADCQ 8(AX), CX - ADCQ 16(AX), BX - ADCQ 24(AX), SI - - // reduce element(DX,CX,BX,SI) using temp registers (R15,DI,R8,R9) - REDUCE(DX,CX,BX,SI,R15,DI,R8,R9) - - MOVQ DX, 0(AX) - MOVQ CX, 8(AX) - MOVQ BX, 16(AX) - MOVQ SI, 24(AX) - RET - -// MulBy13(x *Element) -TEXT ·MulBy13(SB), NOSPLIT, $0-8 - MOVQ x+0(FP), AX - MOVQ 0(AX), DX - MOVQ 8(AX), CX - MOVQ 16(AX), BX - MOVQ 24(AX), SI - ADDQ DX, DX - ADCQ CX, CX - ADCQ BX, BX - ADCQ SI, SI - - // reduce element(DX,CX,BX,SI) using temp registers (DI,R8,R9,R10) - REDUCE(DX,CX,BX,SI,DI,R8,R9,R10) - - ADDQ DX, DX - ADCQ CX, CX - ADCQ BX, BX - ADCQ SI, SI - - // reduce element(DX,CX,BX,SI) using temp registers (R11,R12,R13,R14) - REDUCE(DX,CX,BX,SI,R11,R12,R13,R14) - - MOVQ DX, R11 - MOVQ CX, R12 - MOVQ BX, R13 - MOVQ SI, R14 - ADDQ DX, DX - ADCQ CX, CX - ADCQ BX, BX - ADCQ SI, SI - - // reduce element(DX,CX,BX,SI) using temp registers (DI,R8,R9,R10) - REDUCE(DX,CX,BX,SI,DI,R8,R9,R10) - - ADDQ R11, DX - ADCQ R12, CX - ADCQ R13, BX - ADCQ R14, SI - - // reduce element(DX,CX,BX,SI) using temp registers (DI,R8,R9,R10) - REDUCE(DX,CX,BX,SI,DI,R8,R9,R10) - - ADDQ 0(AX), DX - ADCQ 8(AX), CX - ADCQ 16(AX), BX - ADCQ 24(AX), SI - - // reduce element(DX,CX,BX,SI) using temp registers (DI,R8,R9,R10) - REDUCE(DX,CX,BX,SI,DI,R8,R9,R10) - - MOVQ DX, 0(AX) - MOVQ CX, 8(AX) - MOVQ BX, 16(AX) - MOVQ SI, 24(AX) - RET - -// Butterfly(a, b *Element) sets a = a + b; b = a - b -TEXT ·Butterfly(SB), 
NOSPLIT, $0-16 - MOVQ a+0(FP), AX - MOVQ 0(AX), CX - MOVQ 8(AX), BX - MOVQ 16(AX), SI - MOVQ 24(AX), DI - MOVQ CX, R8 - MOVQ BX, R9 - MOVQ SI, R10 - MOVQ DI, R11 - XORQ AX, AX - MOVQ b+8(FP), DX - ADDQ 0(DX), CX - ADCQ 8(DX), BX - ADCQ 16(DX), SI - ADCQ 24(DX), DI - SUBQ 0(DX), R8 - SBBQ 8(DX), R9 - SBBQ 16(DX), R10 - SBBQ 24(DX), R11 - MOVQ $0xf000000000000001, R12 - MOVQ $0x1cd1e79196bf0e7a, R13 - MOVQ $0xd0b097f28d83cd49, R14 - MOVQ $0x443f917ea68dafc2, R15 - CMOVQCC AX, R12 - CMOVQCC AX, R13 - CMOVQCC AX, R14 - CMOVQCC AX, R15 - ADDQ R12, R8 - ADCQ R13, R9 - ADCQ R14, R10 - ADCQ R15, R11 - MOVQ R8, 0(DX) - MOVQ R9, 8(DX) - MOVQ R10, 16(DX) - MOVQ R11, 24(DX) - - // reduce element(CX,BX,SI,DI) using temp registers (R8,R9,R10,R11) - REDUCE(CX,BX,SI,DI,R8,R9,R10,R11) - - MOVQ a+0(FP), AX - MOVQ CX, 0(AX) - MOVQ BX, 8(AX) - MOVQ SI, 16(AX) - MOVQ DI, 24(AX) - RET - -// addVec(res, a, b *Element, n uint64) res[0...n] = a[0...n] + b[0...n] -TEXT ·addVec(SB), NOSPLIT, $0-32 - MOVQ res+0(FP), CX - MOVQ a+8(FP), AX - MOVQ b+16(FP), DX - MOVQ n+24(FP), BX - -loop_1: - TESTQ BX, BX - JEQ done_2 // n == 0, we are done - - // a[0] -> SI - // a[1] -> DI - // a[2] -> R8 - // a[3] -> R9 - MOVQ 0(AX), SI - MOVQ 8(AX), DI - MOVQ 16(AX), R8 - MOVQ 24(AX), R9 - ADDQ 0(DX), SI - ADCQ 8(DX), DI - ADCQ 16(DX), R8 - ADCQ 24(DX), R9 - - // reduce element(SI,DI,R8,R9) using temp registers (R10,R11,R12,R13) - REDUCE(SI,DI,R8,R9,R10,R11,R12,R13) - - MOVQ SI, 0(CX) - MOVQ DI, 8(CX) - MOVQ R8, 16(CX) - MOVQ R9, 24(CX) - - // increment pointers to visit next element - ADDQ $32, AX - ADDQ $32, DX - ADDQ $32, CX - DECQ BX // decrement n - JMP loop_1 - -done_2: - RET - -// subVec(res, a, b *Element, n uint64) res[0...n] = a[0...n] - b[0...n] -TEXT ·subVec(SB), NOSPLIT, $0-32 - MOVQ res+0(FP), CX - MOVQ a+8(FP), AX - MOVQ b+16(FP), DX - MOVQ n+24(FP), BX - XORQ SI, SI - -loop_3: - TESTQ BX, BX - JEQ done_4 // n == 0, we are done - - // a[0] -> DI - // a[1] -> R8 - // a[2] -> R9 - // a[3] -> R10 - MOVQ 0(AX), DI - MOVQ 8(AX), R8 - MOVQ 16(AX), R9 - MOVQ 24(AX), R10 - SUBQ 0(DX), DI - SBBQ 8(DX), R8 - SBBQ 16(DX), R9 - SBBQ 24(DX), R10 - - // reduce (a-b) mod q - // q[0] -> R11 - // q[1] -> R12 - // q[2] -> R13 - // q[3] -> R14 - MOVQ $0xf000000000000001, R11 - MOVQ $0x1cd1e79196bf0e7a, R12 - MOVQ $0xd0b097f28d83cd49, R13 - MOVQ $0x443f917ea68dafc2, R14 - CMOVQCC SI, R11 - CMOVQCC SI, R12 - CMOVQCC SI, R13 - CMOVQCC SI, R14 - - // add registers (q or 0) to a, and set to result - ADDQ R11, DI - ADCQ R12, R8 - ADCQ R13, R9 - ADCQ R14, R10 - MOVQ DI, 0(CX) - MOVQ R8, 8(CX) - MOVQ R9, 16(CX) - MOVQ R10, 24(CX) - - // increment pointers to visit next element - ADDQ $32, AX - ADDQ $32, DX - ADDQ $32, CX - DECQ BX // decrement n - JMP loop_3 - -done_4: - RET - -// scalarMulVec(res, a, b *Element, n uint64) res[0...n] = a[0...n] * b -TEXT ·scalarMulVec(SB), $56-32 - CMPB ·supportAdx(SB), $1 - JNE noAdx_5 - MOVQ a+8(FP), R11 - MOVQ b+16(FP), R10 - MOVQ n+24(FP), R12 - - // scalar[0] -> SI - // scalar[1] -> DI - // scalar[2] -> R8 - // scalar[3] -> R9 - MOVQ 0(R10), SI - MOVQ 8(R10), DI - MOVQ 16(R10), R8 - MOVQ 24(R10), R9 - MOVQ res+0(FP), R10 - -loop_6: - TESTQ R12, R12 - JEQ done_7 // n == 0, we are done - - // TODO @gbotrel this is generated from the same macro as the unit mul, we should refactor this in a single asm function - // A -> BP - // t[0] -> R14 - // t[1] -> R15 - // t[2] -> CX - // t[3] -> BX - // clear the flags - XORQ AX, AX - MOVQ 0(R11), DX - - // (A,t[0]) := x[0]*y[0] + A - MULXQ SI, R14, R15 - - // (A,t[1]) := 
x[1]*y[0] + A - MULXQ DI, AX, CX - ADOXQ AX, R15 - - // (A,t[2]) := x[2]*y[0] + A - MULXQ R8, AX, BX - ADOXQ AX, CX - - // (A,t[3]) := x[3]*y[0] + A - MULXQ R9, AX, BP - ADOXQ AX, BX - - // A += carries from ADCXQ and ADOXQ - MOVQ $0, AX - ADOXQ AX, BP - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ R14, DX - - // clear the flags - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ q<>+0(SB), AX, R13 - ADCXQ R14, AX - MOVQ R13, R14 - - // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R15, R14 - MULXQ q<>+8(SB), AX, R15 - ADOXQ AX, R14 - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ CX, R15 - MULXQ q<>+16(SB), AX, CX - ADOXQ AX, R15 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ BX, CX - MULXQ q<>+24(SB), AX, BX - ADOXQ AX, CX - - // t[3] = C + A - MOVQ $0, AX - ADCXQ AX, BX - ADOXQ BP, BX - - // clear the flags - XORQ AX, AX - MOVQ 8(R11), DX - - // (A,t[0]) := t[0] + x[0]*y[1] + A - MULXQ SI, AX, BP - ADOXQ AX, R14 - - // (A,t[1]) := t[1] + x[1]*y[1] + A - ADCXQ BP, R15 - MULXQ DI, AX, BP - ADOXQ AX, R15 - - // (A,t[2]) := t[2] + x[2]*y[1] + A - ADCXQ BP, CX - MULXQ R8, AX, BP - ADOXQ AX, CX - - // (A,t[3]) := t[3] + x[3]*y[1] + A - ADCXQ BP, BX - MULXQ R9, AX, BP - ADOXQ AX, BX - - // A += carries from ADCXQ and ADOXQ - MOVQ $0, AX - ADCXQ AX, BP - ADOXQ AX, BP - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ R14, DX - - // clear the flags - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ q<>+0(SB), AX, R13 - ADCXQ R14, AX - MOVQ R13, R14 - - // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R15, R14 - MULXQ q<>+8(SB), AX, R15 - ADOXQ AX, R14 - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ CX, R15 - MULXQ q<>+16(SB), AX, CX - ADOXQ AX, R15 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ BX, CX - MULXQ q<>+24(SB), AX, BX - ADOXQ AX, CX - - // t[3] = C + A - MOVQ $0, AX - ADCXQ AX, BX - ADOXQ BP, BX - - // clear the flags - XORQ AX, AX - MOVQ 16(R11), DX - - // (A,t[0]) := t[0] + x[0]*y[2] + A - MULXQ SI, AX, BP - ADOXQ AX, R14 - - // (A,t[1]) := t[1] + x[1]*y[2] + A - ADCXQ BP, R15 - MULXQ DI, AX, BP - ADOXQ AX, R15 - - // (A,t[2]) := t[2] + x[2]*y[2] + A - ADCXQ BP, CX - MULXQ R8, AX, BP - ADOXQ AX, CX - - // (A,t[3]) := t[3] + x[3]*y[2] + A - ADCXQ BP, BX - MULXQ R9, AX, BP - ADOXQ AX, BX - - // A += carries from ADCXQ and ADOXQ - MOVQ $0, AX - ADCXQ AX, BP - ADOXQ AX, BP - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ R14, DX - - // clear the flags - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ q<>+0(SB), AX, R13 - ADCXQ R14, AX - MOVQ R13, R14 - - // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R15, R14 - MULXQ q<>+8(SB), AX, R15 - ADOXQ AX, R14 - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ CX, R15 - MULXQ q<>+16(SB), AX, CX - ADOXQ AX, R15 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ BX, CX - MULXQ q<>+24(SB), AX, BX - ADOXQ AX, CX - - // t[3] = C + A - MOVQ $0, AX - ADCXQ AX, BX - ADOXQ BP, BX - - // clear the flags - XORQ AX, AX - MOVQ 24(R11), DX - - // (A,t[0]) := t[0] + x[0]*y[3] + A - MULXQ SI, AX, BP - ADOXQ AX, R14 - - // (A,t[1]) := t[1] + x[1]*y[3] + A - ADCXQ BP, R15 - MULXQ DI, AX, BP - ADOXQ AX, R15 - - // (A,t[2]) := t[2] + x[2]*y[3] + A - ADCXQ BP, CX - MULXQ R8, AX, BP - ADOXQ AX, CX - - // (A,t[3]) := t[3] + x[3]*y[3] + A - ADCXQ BP, BX - MULXQ R9, AX, BP - ADOXQ AX, BX - - // A += carries from ADCXQ and ADOXQ - MOVQ $0, AX - ADCXQ AX, BP - ADOXQ AX, BP - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ R14, DX - - // clear the flags - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ q<>+0(SB), AX, R13 - ADCXQ R14, AX - MOVQ R13, R14 - - // (C,t[0]) := t[1] + 
m*q[1] + C - ADCXQ R15, R14 - MULXQ q<>+8(SB), AX, R15 - ADOXQ AX, R14 - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ CX, R15 - MULXQ q<>+16(SB), AX, CX - ADOXQ AX, R15 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ BX, CX - MULXQ q<>+24(SB), AX, BX - ADOXQ AX, CX - - // t[3] = C + A - MOVQ $0, AX - ADCXQ AX, BX - ADOXQ BP, BX - - // reduce t mod q - // reduce element(R14,R15,CX,BX) using temp registers (R13,AX,DX,s0-8(SP)) - REDUCE(R14,R15,CX,BX,R13,AX,DX,s0-8(SP)) - - MOVQ R14, 0(R10) - MOVQ R15, 8(R10) - MOVQ CX, 16(R10) - MOVQ BX, 24(R10) - - // increment pointers to visit next element - ADDQ $32, R11 - ADDQ $32, R10 - DECQ R12 // decrement n - JMP loop_6 - -done_7: - RET - -noAdx_5: - MOVQ n+24(FP), DX - MOVQ res+0(FP), AX - MOVQ AX, (SP) - MOVQ DX, 8(SP) - MOVQ DX, 16(SP) - MOVQ a+8(FP), AX - MOVQ AX, 24(SP) - MOVQ DX, 32(SP) - MOVQ DX, 40(SP) - MOVQ b+16(FP), AX - MOVQ AX, 48(SP) - CALL ·scalarMulVecGeneric(SB) - RET - -// sumVec(res, a *Element, n uint64) res = sum(a[0...n]) -TEXT ·sumVec(SB), NOSPLIT, $0-24 - - // Derived from https://github.com/a16z/vectorized-fields - // The idea is to use Z registers to accumulate the sum of elements, 8 by 8 - // first, we handle the case where n % 8 != 0 - // then, we loop over the elements 8 by 8 and accumulate the sum in the Z registers - // finally, we reduce the sum and store it in res - // - // when we move an element of a into a Z register, we use VPMOVZXDQ - // let's note w0...w3 the 4 64bits words of ai: w0 = ai[0], w1 = ai[1], w2 = ai[2], w3 = ai[3] - // VPMOVZXDQ(ai, Z0) will result in - // Z0= [hi(w3), lo(w3), hi(w2), lo(w2), hi(w1), lo(w1), hi(w0), lo(w0)] - // with hi(wi) the high 32 bits of wi and lo(wi) the low 32 bits of wi - // we can safely add 2^32+1 times Z registers constructed this way without overflow - // since each of this lo/hi bits are moved into a "64bits" slot - // N = 2^64-1 / 2^32-1 = 2^32+1 - // - // we then propagate the carry using ADOXQ and ADCXQ - // r0 = w0l + lo(woh) - // r1 = carry + hi(woh) + w1l + lo(w1h) - // r2 = carry + hi(w1h) + w2l + lo(w2h) - // r3 = carry + hi(w2h) + w3l + lo(w3h) - // r4 = carry + hi(w3h) - // we then reduce the sum using a single-word Barrett reduction - // we pick mu = 2^288 / q; which correspond to 4.5 words max. - // meaning we must guarantee that r4 fits in 32bits. 
- // To do so, we reduce N to 2^32-1 (since r4 receives 2 carries max) - - MOVQ a+8(FP), R14 - MOVQ n+16(FP), R15 - - // initialize accumulators Z0, Z1, Z2, Z3, Z4, Z5, Z6, Z7 - VXORPS Z0, Z0, Z0 - VMOVDQA64 Z0, Z1 - VMOVDQA64 Z0, Z2 - VMOVDQA64 Z0, Z3 - VMOVDQA64 Z0, Z4 - VMOVDQA64 Z0, Z5 - VMOVDQA64 Z0, Z6 - VMOVDQA64 Z0, Z7 - - // n % 8 -> CX - // n / 8 -> R15 - MOVQ R15, CX - ANDQ $7, CX - SHRQ $3, R15 - -loop_single_10: - TESTQ CX, CX - JEQ loop8by8_8 // n % 8 == 0, we are going to loop over 8 by 8 - VPMOVZXDQ 0(R14), Z8 - VPADDQ Z8, Z0, Z0 - ADDQ $32, R14 - DECQ CX // decrement nMod8 - JMP loop_single_10 - -loop8by8_8: - TESTQ R15, R15 - JEQ accumulate_11 // n == 0, we are going to accumulate - VPMOVZXDQ 0*32(R14), Z8 - VPMOVZXDQ 1*32(R14), Z9 - VPMOVZXDQ 2*32(R14), Z10 - VPMOVZXDQ 3*32(R14), Z11 - VPMOVZXDQ 4*32(R14), Z12 - VPMOVZXDQ 5*32(R14), Z13 - VPMOVZXDQ 6*32(R14), Z14 - VPMOVZXDQ 7*32(R14), Z15 - PREFETCHT0 256(R14) - VPADDQ Z8, Z0, Z0 - VPADDQ Z9, Z1, Z1 - VPADDQ Z10, Z2, Z2 - VPADDQ Z11, Z3, Z3 - VPADDQ Z12, Z4, Z4 - VPADDQ Z13, Z5, Z5 - VPADDQ Z14, Z6, Z6 - VPADDQ Z15, Z7, Z7 - - // increment pointers to visit next 8 elements - ADDQ $256, R14 - DECQ R15 // decrement n - JMP loop8by8_8 - -accumulate_11: - // accumulate the 8 Z registers into Z0 - VPADDQ Z7, Z6, Z6 - VPADDQ Z6, Z5, Z5 - VPADDQ Z5, Z4, Z4 - VPADDQ Z4, Z3, Z3 - VPADDQ Z3, Z2, Z2 - VPADDQ Z2, Z1, Z1 - VPADDQ Z1, Z0, Z0 - - // carry propagation - // lo(w0) -> BX - // hi(w0) -> SI - // lo(w1) -> DI - // hi(w1) -> R8 - // lo(w2) -> R9 - // hi(w2) -> R10 - // lo(w3) -> R11 - // hi(w3) -> R12 - VMOVQ X0, BX - VALIGNQ $1, Z0, Z0, Z0 - VMOVQ X0, SI - VALIGNQ $1, Z0, Z0, Z0 - VMOVQ X0, DI - VALIGNQ $1, Z0, Z0, Z0 - VMOVQ X0, R8 - VALIGNQ $1, Z0, Z0, Z0 - VMOVQ X0, R9 - VALIGNQ $1, Z0, Z0, Z0 - VMOVQ X0, R10 - VALIGNQ $1, Z0, Z0, Z0 - VMOVQ X0, R11 - VALIGNQ $1, Z0, Z0, Z0 - VMOVQ X0, R12 - - // lo(hi(wo)) -> R13 - // lo(hi(w1)) -> CX - // lo(hi(w2)) -> R15 - // lo(hi(w3)) -> R14 -#define SPLIT_LO_HI(lo, hi) \ - MOVQ hi, lo; \ - ANDQ $0xffffffff, lo; \ - SHLQ $32, lo; \ - SHRQ $32, hi; \ - - SPLIT_LO_HI(R13, SI) - SPLIT_LO_HI(CX, R8) - SPLIT_LO_HI(R15, R10) - SPLIT_LO_HI(R14, R12) - - // r0 = w0l + lo(woh) - // r1 = carry + hi(woh) + w1l + lo(w1h) - // r2 = carry + hi(w1h) + w2l + lo(w2h) - // r3 = carry + hi(w2h) + w3l + lo(w3h) - // r4 = carry + hi(w3h) - - XORQ AX, AX // clear the flags - ADOXQ R13, BX - ADOXQ CX, DI - ADCXQ SI, DI - ADOXQ R15, R9 - ADCXQ R8, R9 - ADOXQ R14, R11 - ADCXQ R10, R11 - ADOXQ AX, R12 - ADCXQ AX, R12 - - // r[0] -> BX - // r[1] -> DI - // r[2] -> R9 - // r[3] -> R11 - // r[4] -> R12 - // reduce using single-word Barrett - // mu=2^288 / q -> SI - MOVQ mu<>(SB), SI - MOVQ R11, AX - SHRQ $32, R12, AX - MULQ SI // high bits of res stored in DX - MULXQ q<>+0(SB), AX, SI - SUBQ AX, BX - SBBQ SI, DI - MULXQ q<>+16(SB), AX, SI - SBBQ AX, R9 - SBBQ SI, R11 - SBBQ $0, R12 - MULXQ q<>+8(SB), AX, SI - SUBQ AX, DI - SBBQ SI, R9 - MULXQ q<>+24(SB), AX, SI - SBBQ AX, R11 - SBBQ SI, R12 - MOVQ BX, R8 - MOVQ DI, R10 - MOVQ R9, R13 - MOVQ R11, CX - SUBQ q<>+0(SB), BX - SBBQ q<>+8(SB), DI - SBBQ q<>+16(SB), R9 - SBBQ q<>+24(SB), R11 - SBBQ $0, R12 - JCS modReduced_12 - MOVQ BX, R8 - MOVQ DI, R10 - MOVQ R9, R13 - MOVQ R11, CX - SUBQ q<>+0(SB), BX - SBBQ q<>+8(SB), DI - SBBQ q<>+16(SB), R9 - SBBQ q<>+24(SB), R11 - SBBQ $0, R12 - JCS modReduced_12 - MOVQ BX, R8 - MOVQ DI, R10 - MOVQ R9, R13 - MOVQ R11, CX - -modReduced_12: - MOVQ res+0(FP), SI - MOVQ R8, 0(SI) - MOVQ R10, 8(SI) - MOVQ R13, 16(SI) - MOVQ CX, 
24(SI) - -done_9: - RET diff --git a/ecc/bls24-317/fr/element_ops_purego.go b/ecc/bls24-317/fr/element_ops_purego.go index f18cc97c5..c7cae5f72 100644 --- a/ecc/bls24-317/fr/element_ops_purego.go +++ b/ecc/bls24-317/fr/element_ops_purego.go @@ -89,48 +89,8 @@ func (vector *Vector) Sum() (res Element) { // x and y must be less than q func (z *Element) Mul(x, y *Element) *Element { - // Implements CIOS multiplication -- section 2.3.2 of Tolga Acar's thesis - // https://www.microsoft.com/en-us/research/wp-content/uploads/1998/06/97Acar.pdf - // - // The algorithm: - // - // for i=0 to N-1 - // C := 0 - // for j=0 to N-1 - // (C,t[j]) := t[j] + x[j]*y[i] + C - // (t[N+1],t[N]) := t[N] + C - // - // C := 0 - // m := t[0]*q'[0] mod D - // (C,_) := t[0] + m*q[0] - // for j=1 to N-1 - // (C,t[j-1]) := t[j] + m*q[j] + C - // - // (C,t[N-1]) := t[N] + C - // t[N] := t[N+1] + C - // - // → N is the number of machine words needed to store the modulus q - // → D is the word size. For example, on a 64-bit architecture D is 2 64 - // → x[i], y[i], q[i] is the ith word of the numbers x,y,q - // → q'[0] is the lowest word of the number -q⁻¹ mod r. This quantity is pre-computed, as it does not depend on the inputs. - // → t is a temporary array of size N+2 - // → C, S are machine words. A pair (C,S) refers to (hi-bits, lo-bits) of a two-word number - // - // As described here https://hackmd.io/@gnark/modular_multiplication we can get rid of one carry chain and simplify: - // (also described in https://eprint.iacr.org/2022/1400.pdf annex) - // - // for i=0 to N-1 - // (A,t[0]) := t[0] + x[0]*y[i] - // m := t[0]*q'[0] mod W - // C,_ := t[0] + m*q[0] - // for j=1 to N-1 - // (A,t[j]) := t[j] + x[j]*y[i] + A - // (C,t[j-1]) := t[j] + m*q[j] + C - // - // t[N-1] = C + A - // - // This optimization saves 5N + 2 additions in the algorithm, and can be used whenever the highest bit - // of the modulus is zero (and not all of the remaining bits are set). + // Algorithm 2 of "Faster Montgomery Multiplication and Multi-Scalar-Multiplication for SNARKS" + // by Y. El Housni and G. Botrel https://doi.org/10.46586/tches.v2023.i3.504-521 var t0, t1, t2, t3 uint64 var u0, u1, u2, u3 uint64 diff --git a/ecc/bls24-317/fr/element_test.go b/ecc/bls24-317/fr/element_test.go index 84f8af2df..e0d954f8a 100644 --- a/ecc/bls24-317/fr/element_test.go +++ b/ecc/bls24-317/fr/element_test.go @@ -637,7 +637,6 @@ func TestElementBitLen(t *testing.T) { )) properties.TestingRun(t, gopter.ConsoleReporter(false)) - } func TestElementButterflies(t *testing.T) { @@ -719,11 +718,17 @@ func TestElementVecOps(t *testing.T) { // set m to max values element // it's not really q-1 (since we have montgomery representation) // but it's the "largest" legal value - qMinus1 := new(big.Int).Sub(Modulus(), big.NewInt(1)) - - var eQMinus1 Element - for i, v := range qMinus1.Bits() { - eQMinus1[i] = uint64(v) + eQMinus1 := qElement + if eQMinus1[0] != 0 { + eQMinus1[0]-- + } else { + eQMinus1[0] = ^uint64(0) + for i := 1; i < len(eQMinus1); i++ { + if eQMinus1[i] != 0 { + eQMinus1[i]-- + break + } + } } for i := 0; i < N; i++ { diff --git a/ecc/bn254/fp/element.go b/ecc/bn254/fp/element.go index 5ba388e73..45f6861ce 100644 --- a/ecc/bn254/fp/element.go +++ b/ecc/bn254/fp/element.go @@ -477,32 +477,8 @@ func (z *Element) Select(c int, x0 *Element, x1 *Element) *Element { // and is used for testing purposes. 
func _mulGeneric(z, x, y *Element) { - // Implements CIOS multiplication -- section 2.3.2 of Tolga Acar's thesis - // https://www.microsoft.com/en-us/research/wp-content/uploads/1998/06/97Acar.pdf - // - // The algorithm: - // - // for i=0 to N-1 - // C := 0 - // for j=0 to N-1 - // (C,t[j]) := t[j] + x[j]*y[i] + C - // (t[N+1],t[N]) := t[N] + C - // - // C := 0 - // m := t[0]*q'[0] mod D - // (C,_) := t[0] + m*q[0] - // for j=1 to N-1 - // (C,t[j-1]) := t[j] + m*q[j] + C - // - // (C,t[N-1]) := t[N] + C - // t[N] := t[N+1] + C - // - // → N is the number of machine words needed to store the modulus q - // → D is the word size. For example, on a 64-bit architecture D is 2 64 - // → x[i], y[i], q[i] is the ith word of the numbers x,y,q - // → q'[0] is the lowest word of the number -q⁻¹ mod r. This quantity is pre-computed, as it does not depend on the inputs. - // → t is a temporary array of size N+2 - // → C, S are machine words. A pair (C,S) refers to (hi-bits, lo-bits) of a two-word number + // Algorithm 2 of "Faster Montgomery Multiplication and Multi-Scalar-Multiplication for SNARKS" + // by Y. El Housni and G. Botrel https://doi.org/10.46586/tches.v2023.i3.504-521 var t [5]uint64 var D uint64 diff --git a/ecc/bn254/fp/element_mul_amd64.s b/ecc/bn254/fp/element_mul_amd64.s deleted file mode 100644 index 95e95e6fa..000000000 --- a/ecc/bn254/fp/element_mul_amd64.s +++ /dev/null @@ -1,490 +0,0 @@ -// +build !purego - -// Copyright 2020 ConsenSys Software Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
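The sumVec routines removed in this patch rest on one idea, described in the long comment above their TEXT symbols: zero-extend every 32-bit half-limb into its own 64-bit lane (VPMOVZXDQ), so that on the order of 2^32 elements can be added lane-wise (VPADDQ) without overflow, then recombine the lanes and reduce. Below is a scalar Go sketch of that bookkeeping under illustrative names (element, sumVecSketch); it falls back to math/big for the final reduction, where the assembly uses a single-word Barrett step with mu = floor(2^288/q) followed by conditional subtractions of q.

package sumsketch

import "math/big"

// element mirrors the 4x64-bit limb layout used throughout this patch.
type element [4]uint64

// sumVecSketch adds up a[...] the way sumVec does: per-32-bit-lane
// accumulation, lane recombination, then reduction mod q.
func sumVecSketch(a []element, q *big.Int) *big.Int {
	// lanes[2k] sums lo32(limb k), lanes[2k+1] sums hi32(limb k).
	// Each lane grows by less than 2^32 per element, so up to 2^32-1
	// elements fit without overflowing a 64-bit accumulator -- the bound
	// derived in the removed comment.
	var lanes [8]uint64
	for _, e := range a {
		for k := 0; k < 4; k++ {
			lanes[2*k] += e[k] & 0xffffffff
			lanes[2*k+1] += e[k] >> 32
		}
	}
	// Recombine: sum = sum_k (lanes[2k] + lanes[2k+1]<<32) << (64*k).
	sum := new(big.Int)
	for k := 3; k >= 0; k-- {
		sum.Lsh(sum, 64)
		v := new(big.Int).SetUint64(lanes[2*k+1])
		v.Lsh(v, 32)
		v.Add(v, new(big.Int).SetUint64(lanes[2*k]))
		sum.Add(sum, v)
	}
	// The assembly reduces this (at most 288-bit) value with a single-word
	// Barrett reduction; big.Int.Mod stands in for it here.
	return sum.Mod(sum, q)
}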
- -#include "textflag.h" -#include "funcdata.h" - -// modulus q -DATA q<>+0(SB)/8, $0x3c208c16d87cfd47 -DATA q<>+8(SB)/8, $0x97816a916871ca8d -DATA q<>+16(SB)/8, $0xb85045b68181585d -DATA q<>+24(SB)/8, $0x30644e72e131a029 -GLOBL q<>(SB), (RODATA+NOPTR), $32 - -// qInv0 q'[0] -DATA qInv0<>(SB)/8, $0x87d20782e4866389 -GLOBL qInv0<>(SB), (RODATA+NOPTR), $8 -// Mu -DATA mu<>(SB)/8, $0x000000054a474626 -GLOBL mu<>(SB), (RODATA+NOPTR), $8 - -#define REDUCE(ra0, ra1, ra2, ra3, rb0, rb1, rb2, rb3) \ - MOVQ ra0, rb0; \ - SUBQ q<>(SB), ra0; \ - MOVQ ra1, rb1; \ - SBBQ q<>+8(SB), ra1; \ - MOVQ ra2, rb2; \ - SBBQ q<>+16(SB), ra2; \ - MOVQ ra3, rb3; \ - SBBQ q<>+24(SB), ra3; \ - CMOVQCS rb0, ra0; \ - CMOVQCS rb1, ra1; \ - CMOVQCS rb2, ra2; \ - CMOVQCS rb3, ra3; \ - -// mul(res, x, y *Element) -TEXT ·mul(SB), $24-24 - - // the algorithm is described in the Element.Mul declaration (.go) - // however, to benefit from the ADCX and ADOX carry chains - // we split the inner loops in 2: - // for i=0 to N-1 - // for j=0 to N-1 - // (A,t[j]) := t[j] + x[j]*y[i] + A - // m := t[0]*q'[0] mod W - // C,_ := t[0] + m*q[0] - // for j=1 to N-1 - // (C,t[j-1]) := t[j] + m*q[j] + C - // t[N-1] = C + A - - NO_LOCAL_POINTERS - CMPB ·supportAdx(SB), $1 - JNE noAdx_1 - MOVQ x+8(FP), SI - - // x[0] -> DI - // x[1] -> R8 - // x[2] -> R9 - // x[3] -> R10 - MOVQ 0(SI), DI - MOVQ 8(SI), R8 - MOVQ 16(SI), R9 - MOVQ 24(SI), R10 - MOVQ y+16(FP), R11 - - // A -> BP - // t[0] -> R14 - // t[1] -> R13 - // t[2] -> CX - // t[3] -> BX - // clear the flags - XORQ AX, AX - MOVQ 0(R11), DX - - // (A,t[0]) := x[0]*y[0] + A - MULXQ DI, R14, R13 - - // (A,t[1]) := x[1]*y[0] + A - MULXQ R8, AX, CX - ADOXQ AX, R13 - - // (A,t[2]) := x[2]*y[0] + A - MULXQ R9, AX, BX - ADOXQ AX, CX - - // (A,t[3]) := x[3]*y[0] + A - MULXQ R10, AX, BP - ADOXQ AX, BX - - // A += carries from ADCXQ and ADOXQ - MOVQ $0, AX - ADOXQ AX, BP - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ R14, DX - - // clear the flags - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ q<>+0(SB), AX, R12 - ADCXQ R14, AX - MOVQ R12, R14 - - // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R13, R14 - MULXQ q<>+8(SB), AX, R13 - ADOXQ AX, R14 - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ CX, R13 - MULXQ q<>+16(SB), AX, CX - ADOXQ AX, R13 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ BX, CX - MULXQ q<>+24(SB), AX, BX - ADOXQ AX, CX - - // t[3] = C + A - MOVQ $0, AX - ADCXQ AX, BX - ADOXQ BP, BX - - // clear the flags - XORQ AX, AX - MOVQ 8(R11), DX - - // (A,t[0]) := t[0] + x[0]*y[1] + A - MULXQ DI, AX, BP - ADOXQ AX, R14 - - // (A,t[1]) := t[1] + x[1]*y[1] + A - ADCXQ BP, R13 - MULXQ R8, AX, BP - ADOXQ AX, R13 - - // (A,t[2]) := t[2] + x[2]*y[1] + A - ADCXQ BP, CX - MULXQ R9, AX, BP - ADOXQ AX, CX - - // (A,t[3]) := t[3] + x[3]*y[1] + A - ADCXQ BP, BX - MULXQ R10, AX, BP - ADOXQ AX, BX - - // A += carries from ADCXQ and ADOXQ - MOVQ $0, AX - ADCXQ AX, BP - ADOXQ AX, BP - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ R14, DX - - // clear the flags - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ q<>+0(SB), AX, R12 - ADCXQ R14, AX - MOVQ R12, R14 - - // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R13, R14 - MULXQ q<>+8(SB), AX, R13 - ADOXQ AX, R14 - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ CX, R13 - MULXQ q<>+16(SB), AX, CX - ADOXQ AX, R13 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ BX, CX - MULXQ q<>+24(SB), AX, BX - ADOXQ AX, CX - - // t[3] = C + A - MOVQ $0, AX - ADCXQ AX, BX - ADOXQ BP, BX - - // clear the flags - XORQ AX, AX - MOVQ 16(R11), DX - - // (A,t[0]) := t[0] + x[0]*y[2] 
+ A - MULXQ DI, AX, BP - ADOXQ AX, R14 - - // (A,t[1]) := t[1] + x[1]*y[2] + A - ADCXQ BP, R13 - MULXQ R8, AX, BP - ADOXQ AX, R13 - - // (A,t[2]) := t[2] + x[2]*y[2] + A - ADCXQ BP, CX - MULXQ R9, AX, BP - ADOXQ AX, CX - - // (A,t[3]) := t[3] + x[3]*y[2] + A - ADCXQ BP, BX - MULXQ R10, AX, BP - ADOXQ AX, BX - - // A += carries from ADCXQ and ADOXQ - MOVQ $0, AX - ADCXQ AX, BP - ADOXQ AX, BP - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ R14, DX - - // clear the flags - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ q<>+0(SB), AX, R12 - ADCXQ R14, AX - MOVQ R12, R14 - - // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R13, R14 - MULXQ q<>+8(SB), AX, R13 - ADOXQ AX, R14 - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ CX, R13 - MULXQ q<>+16(SB), AX, CX - ADOXQ AX, R13 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ BX, CX - MULXQ q<>+24(SB), AX, BX - ADOXQ AX, CX - - // t[3] = C + A - MOVQ $0, AX - ADCXQ AX, BX - ADOXQ BP, BX - - // clear the flags - XORQ AX, AX - MOVQ 24(R11), DX - - // (A,t[0]) := t[0] + x[0]*y[3] + A - MULXQ DI, AX, BP - ADOXQ AX, R14 - - // (A,t[1]) := t[1] + x[1]*y[3] + A - ADCXQ BP, R13 - MULXQ R8, AX, BP - ADOXQ AX, R13 - - // (A,t[2]) := t[2] + x[2]*y[3] + A - ADCXQ BP, CX - MULXQ R9, AX, BP - ADOXQ AX, CX - - // (A,t[3]) := t[3] + x[3]*y[3] + A - ADCXQ BP, BX - MULXQ R10, AX, BP - ADOXQ AX, BX - - // A += carries from ADCXQ and ADOXQ - MOVQ $0, AX - ADCXQ AX, BP - ADOXQ AX, BP - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ R14, DX - - // clear the flags - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ q<>+0(SB), AX, R12 - ADCXQ R14, AX - MOVQ R12, R14 - - // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R13, R14 - MULXQ q<>+8(SB), AX, R13 - ADOXQ AX, R14 - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ CX, R13 - MULXQ q<>+16(SB), AX, CX - ADOXQ AX, R13 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ BX, CX - MULXQ q<>+24(SB), AX, BX - ADOXQ AX, CX - - // t[3] = C + A - MOVQ $0, AX - ADCXQ AX, BX - ADOXQ BP, BX - - // reduce element(R14,R13,CX,BX) using temp registers (SI,R12,R11,DI) - REDUCE(R14,R13,CX,BX,SI,R12,R11,DI) - - MOVQ res+0(FP), AX - MOVQ R14, 0(AX) - MOVQ R13, 8(AX) - MOVQ CX, 16(AX) - MOVQ BX, 24(AX) - RET - -noAdx_1: - MOVQ res+0(FP), AX - MOVQ AX, (SP) - MOVQ x+8(FP), AX - MOVQ AX, 8(SP) - MOVQ y+16(FP), AX - MOVQ AX, 16(SP) - CALL ·_mulGeneric(SB) - RET - -TEXT ·fromMont(SB), $8-8 - NO_LOCAL_POINTERS - - // the algorithm is described here - // https://hackmd.io/@gnark/modular_multiplication - // when y = 1 we have: - // for i=0 to N-1 - // t[i] = x[i] - // for i=0 to N-1 - // m := t[0]*q'[0] mod W - // C,_ := t[0] + m*q[0] - // for j=1 to N-1 - // (C,t[j-1]) := t[j] + m*q[j] + C - // t[N-1] = C - CMPB ·supportAdx(SB), $1 - JNE noAdx_2 - MOVQ res+0(FP), DX - MOVQ 0(DX), R14 - MOVQ 8(DX), R13 - MOVQ 16(DX), CX - MOVQ 24(DX), BX - XORQ DX, DX - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ R14, DX - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ q<>+0(SB), AX, BP - ADCXQ R14, AX - MOVQ BP, R14 - - // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R13, R14 - MULXQ q<>+8(SB), AX, R13 - ADOXQ AX, R14 - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ CX, R13 - MULXQ q<>+16(SB), AX, CX - ADOXQ AX, R13 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ BX, CX - MULXQ q<>+24(SB), AX, BX - ADOXQ AX, CX - MOVQ $0, AX - ADCXQ AX, BX - ADOXQ AX, BX - XORQ DX, DX - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ R14, DX - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ q<>+0(SB), AX, BP - ADCXQ R14, AX - MOVQ BP, R14 - - // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ 
R13, R14 - MULXQ q<>+8(SB), AX, R13 - ADOXQ AX, R14 - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ CX, R13 - MULXQ q<>+16(SB), AX, CX - ADOXQ AX, R13 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ BX, CX - MULXQ q<>+24(SB), AX, BX - ADOXQ AX, CX - MOVQ $0, AX - ADCXQ AX, BX - ADOXQ AX, BX - XORQ DX, DX - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ R14, DX - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ q<>+0(SB), AX, BP - ADCXQ R14, AX - MOVQ BP, R14 - - // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R13, R14 - MULXQ q<>+8(SB), AX, R13 - ADOXQ AX, R14 - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ CX, R13 - MULXQ q<>+16(SB), AX, CX - ADOXQ AX, R13 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ BX, CX - MULXQ q<>+24(SB), AX, BX - ADOXQ AX, CX - MOVQ $0, AX - ADCXQ AX, BX - ADOXQ AX, BX - XORQ DX, DX - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ R14, DX - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ q<>+0(SB), AX, BP - ADCXQ R14, AX - MOVQ BP, R14 - - // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R13, R14 - MULXQ q<>+8(SB), AX, R13 - ADOXQ AX, R14 - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ CX, R13 - MULXQ q<>+16(SB), AX, CX - ADOXQ AX, R13 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ BX, CX - MULXQ q<>+24(SB), AX, BX - ADOXQ AX, CX - MOVQ $0, AX - ADCXQ AX, BX - ADOXQ AX, BX - - // reduce element(R14,R13,CX,BX) using temp registers (SI,DI,R8,R9) - REDUCE(R14,R13,CX,BX,SI,DI,R8,R9) - - MOVQ res+0(FP), AX - MOVQ R14, 0(AX) - MOVQ R13, 8(AX) - MOVQ CX, 16(AX) - MOVQ BX, 24(AX) - RET - -noAdx_2: - MOVQ res+0(FP), AX - MOVQ AX, (SP) - CALL ·_fromMontGeneric(SB) - RET diff --git a/ecc/bn254/fp/element_ops_amd64.go b/ecc/bn254/fp/element_ops_amd64.go index 328d9c4ab..4b1a9f449 100644 --- a/ecc/bn254/fp/element_ops_amd64.go +++ b/ecc/bn254/fp/element_ops_amd64.go @@ -103,48 +103,8 @@ func sumVec(res *Element, a *Element, n uint64) // x and y must be less than q func (z *Element) Mul(x, y *Element) *Element { - // Implements CIOS multiplication -- section 2.3.2 of Tolga Acar's thesis - // https://www.microsoft.com/en-us/research/wp-content/uploads/1998/06/97Acar.pdf - // - // The algorithm: - // - // for i=0 to N-1 - // C := 0 - // for j=0 to N-1 - // (C,t[j]) := t[j] + x[j]*y[i] + C - // (t[N+1],t[N]) := t[N] + C - // - // C := 0 - // m := t[0]*q'[0] mod D - // (C,_) := t[0] + m*q[0] - // for j=1 to N-1 - // (C,t[j-1]) := t[j] + m*q[j] + C - // - // (C,t[N-1]) := t[N] + C - // t[N] := t[N+1] + C - // - // → N is the number of machine words needed to store the modulus q - // → D is the word size. For example, on a 64-bit architecture D is 2 64 - // → x[i], y[i], q[i] is the ith word of the numbers x,y,q - // → q'[0] is the lowest word of the number -q⁻¹ mod r. This quantity is pre-computed, as it does not depend on the inputs. - // → t is a temporary array of size N+2 - // → C, S are machine words. A pair (C,S) refers to (hi-bits, lo-bits) of a two-word number - // - // As described here https://hackmd.io/@gnark/modular_multiplication we can get rid of one carry chain and simplify: - // (also described in https://eprint.iacr.org/2022/1400.pdf annex) - // - // for i=0 to N-1 - // (A,t[0]) := t[0] + x[0]*y[i] - // m := t[0]*q'[0] mod W - // C,_ := t[0] + m*q[0] - // for j=1 to N-1 - // (A,t[j]) := t[j] + x[j]*y[i] + A - // (C,t[j-1]) := t[j] + m*q[j] + C - // - // t[N-1] = C + A - // - // This optimization saves 5N + 2 additions in the algorithm, and can be used whenever the highest bit - // of the modulus is zero (and not all of the remaining bits are set). 
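The addVec, subVec and Butterfly bodies removed by this patch all correct their results the same way: full-width limb arithmetic followed by one conditional adjustment by q (SUBQ/SBBQ plus CMOVQCS in the REDUCE macro, or a CMOVQCC-masked add of q after a borrow). The branching Go sketch below shows both corrections; addModSketch/subModSketch and the explicit q parameter are illustrative names, and the [4]uint64 layout and math/bits import are the same as in the multiplication sketch earlier.

// addModSketch returns a + b mod q. For the moduli in this patch the top
// bit of q is zero, so a+b never overflows 256 bits and the final carry
// can be dropped, exactly as addVec does.
func addModSketch(a, b, q [4]uint64) [4]uint64 {
	var z, t [4]uint64
	var c uint64
	z[0], c = bits.Add64(a[0], b[0], 0)
	z[1], c = bits.Add64(a[1], b[1], c)
	z[2], c = bits.Add64(a[2], b[2], c)
	z[3], _ = bits.Add64(a[3], b[3], c)
	// REDUCE: subtract q and keep the difference only if it did not borrow.
	var borrow uint64
	t[0], borrow = bits.Sub64(z[0], q[0], 0)
	t[1], borrow = bits.Sub64(z[1], q[1], borrow)
	t[2], borrow = bits.Sub64(z[2], q[2], borrow)
	t[3], borrow = bits.Sub64(z[3], q[3], borrow)
	if borrow == 0 {
		return t
	}
	return z
}

// subModSketch returns a - b mod q. On borrow, q is added back; the assembly
// does this branch-free by CMOVQCC-ing q down to zero when no borrow occurred.
func subModSketch(a, b, q [4]uint64) [4]uint64 {
	var z [4]uint64
	var borrow uint64
	z[0], borrow = bits.Sub64(a[0], b[0], 0)
	z[1], borrow = bits.Sub64(a[1], b[1], borrow)
	z[2], borrow = bits.Sub64(a[2], b[2], borrow)
	z[3], borrow = bits.Sub64(a[3], b[3], borrow)
	if borrow != 0 { // a < b: add q back
		var c uint64
		z[0], c = bits.Add64(z[0], q[0], 0)
		z[1], c = bits.Add64(z[1], q[1], c)
		z[2], c = bits.Add64(z[2], q[2], c)
		z[3], _ = bits.Add64(z[3], q[3], c)
	}
	return z
}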
+ // Algorithm 2 of "Faster Montgomery Multiplication and Multi-Scalar-Multiplication for SNARKS" + // by Y. El Housni and G. Botrel https://doi.org/10.46586/tches.v2023.i3.504-521 mul(z, x, y) return z diff --git a/ecc/bn254/fp/element_ops_amd64.s b/ecc/bn254/fp/element_ops_amd64.s index 44c900aad..199754c5c 100644 --- a/ecc/bn254/fp/element_ops_amd64.s +++ b/ecc/bn254/fp/element_ops_amd64.s @@ -1,21 +1,11 @@ // +build !purego -// Copyright 2020 ConsenSys Software Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. +#define q0 $0x3c208c16d87cfd47 +#define q1 $0x97816a916871ca8d +#define q2 $0xb85045b68181585d +#define q3 $0x30644e72e131a029 -#include "textflag.h" -#include "funcdata.h" +#include "../../../field/asm/element_4w_amd64.h" // modulus q DATA q<>+0(SB)/8, $0x3c208c16d87cfd47 @@ -31,809 +21,3 @@ GLOBL qInv0<>(SB), (RODATA+NOPTR), $8 DATA mu<>(SB)/8, $0x000000054a474626 GLOBL mu<>(SB), (RODATA+NOPTR), $8 -#define REDUCE(ra0, ra1, ra2, ra3, rb0, rb1, rb2, rb3) \ - MOVQ ra0, rb0; \ - SUBQ q<>(SB), ra0; \ - MOVQ ra1, rb1; \ - SBBQ q<>+8(SB), ra1; \ - MOVQ ra2, rb2; \ - SBBQ q<>+16(SB), ra2; \ - MOVQ ra3, rb3; \ - SBBQ q<>+24(SB), ra3; \ - CMOVQCS rb0, ra0; \ - CMOVQCS rb1, ra1; \ - CMOVQCS rb2, ra2; \ - CMOVQCS rb3, ra3; \ - -TEXT ·reduce(SB), NOSPLIT, $0-8 - MOVQ res+0(FP), AX - MOVQ 0(AX), DX - MOVQ 8(AX), CX - MOVQ 16(AX), BX - MOVQ 24(AX), SI - - // reduce element(DX,CX,BX,SI) using temp registers (DI,R8,R9,R10) - REDUCE(DX,CX,BX,SI,DI,R8,R9,R10) - - MOVQ DX, 0(AX) - MOVQ CX, 8(AX) - MOVQ BX, 16(AX) - MOVQ SI, 24(AX) - RET - -// MulBy3(x *Element) -TEXT ·MulBy3(SB), NOSPLIT, $0-8 - MOVQ x+0(FP), AX - MOVQ 0(AX), DX - MOVQ 8(AX), CX - MOVQ 16(AX), BX - MOVQ 24(AX), SI - ADDQ DX, DX - ADCQ CX, CX - ADCQ BX, BX - ADCQ SI, SI - - // reduce element(DX,CX,BX,SI) using temp registers (DI,R8,R9,R10) - REDUCE(DX,CX,BX,SI,DI,R8,R9,R10) - - ADDQ 0(AX), DX - ADCQ 8(AX), CX - ADCQ 16(AX), BX - ADCQ 24(AX), SI - - // reduce element(DX,CX,BX,SI) using temp registers (R11,R12,R13,R14) - REDUCE(DX,CX,BX,SI,R11,R12,R13,R14) - - MOVQ DX, 0(AX) - MOVQ CX, 8(AX) - MOVQ BX, 16(AX) - MOVQ SI, 24(AX) - RET - -// MulBy5(x *Element) -TEXT ·MulBy5(SB), NOSPLIT, $0-8 - MOVQ x+0(FP), AX - MOVQ 0(AX), DX - MOVQ 8(AX), CX - MOVQ 16(AX), BX - MOVQ 24(AX), SI - ADDQ DX, DX - ADCQ CX, CX - ADCQ BX, BX - ADCQ SI, SI - - // reduce element(DX,CX,BX,SI) using temp registers (DI,R8,R9,R10) - REDUCE(DX,CX,BX,SI,DI,R8,R9,R10) - - ADDQ DX, DX - ADCQ CX, CX - ADCQ BX, BX - ADCQ SI, SI - - // reduce element(DX,CX,BX,SI) using temp registers (R11,R12,R13,R14) - REDUCE(DX,CX,BX,SI,R11,R12,R13,R14) - - ADDQ 0(AX), DX - ADCQ 8(AX), CX - ADCQ 16(AX), BX - ADCQ 24(AX), SI - - // reduce element(DX,CX,BX,SI) using temp registers (R15,DI,R8,R9) - REDUCE(DX,CX,BX,SI,R15,DI,R8,R9) - - MOVQ DX, 0(AX) - MOVQ CX, 8(AX) - MOVQ BX, 16(AX) - MOVQ SI, 24(AX) - RET - -// MulBy13(x *Element) -TEXT ·MulBy13(SB), NOSPLIT, $0-8 - MOVQ x+0(FP), AX - MOVQ 0(AX), DX - MOVQ 8(AX), CX - MOVQ 16(AX), BX - MOVQ 24(AX), SI - ADDQ DX, DX - ADCQ 
CX, CX - ADCQ BX, BX - ADCQ SI, SI - - // reduce element(DX,CX,BX,SI) using temp registers (DI,R8,R9,R10) - REDUCE(DX,CX,BX,SI,DI,R8,R9,R10) - - ADDQ DX, DX - ADCQ CX, CX - ADCQ BX, BX - ADCQ SI, SI - - // reduce element(DX,CX,BX,SI) using temp registers (R11,R12,R13,R14) - REDUCE(DX,CX,BX,SI,R11,R12,R13,R14) - - MOVQ DX, R11 - MOVQ CX, R12 - MOVQ BX, R13 - MOVQ SI, R14 - ADDQ DX, DX - ADCQ CX, CX - ADCQ BX, BX - ADCQ SI, SI - - // reduce element(DX,CX,BX,SI) using temp registers (DI,R8,R9,R10) - REDUCE(DX,CX,BX,SI,DI,R8,R9,R10) - - ADDQ R11, DX - ADCQ R12, CX - ADCQ R13, BX - ADCQ R14, SI - - // reduce element(DX,CX,BX,SI) using temp registers (DI,R8,R9,R10) - REDUCE(DX,CX,BX,SI,DI,R8,R9,R10) - - ADDQ 0(AX), DX - ADCQ 8(AX), CX - ADCQ 16(AX), BX - ADCQ 24(AX), SI - - // reduce element(DX,CX,BX,SI) using temp registers (DI,R8,R9,R10) - REDUCE(DX,CX,BX,SI,DI,R8,R9,R10) - - MOVQ DX, 0(AX) - MOVQ CX, 8(AX) - MOVQ BX, 16(AX) - MOVQ SI, 24(AX) - RET - -// Butterfly(a, b *Element) sets a = a + b; b = a - b -TEXT ·Butterfly(SB), NOSPLIT, $0-16 - MOVQ a+0(FP), AX - MOVQ 0(AX), CX - MOVQ 8(AX), BX - MOVQ 16(AX), SI - MOVQ 24(AX), DI - MOVQ CX, R8 - MOVQ BX, R9 - MOVQ SI, R10 - MOVQ DI, R11 - XORQ AX, AX - MOVQ b+8(FP), DX - ADDQ 0(DX), CX - ADCQ 8(DX), BX - ADCQ 16(DX), SI - ADCQ 24(DX), DI - SUBQ 0(DX), R8 - SBBQ 8(DX), R9 - SBBQ 16(DX), R10 - SBBQ 24(DX), R11 - MOVQ $0x3c208c16d87cfd47, R12 - MOVQ $0x97816a916871ca8d, R13 - MOVQ $0xb85045b68181585d, R14 - MOVQ $0x30644e72e131a029, R15 - CMOVQCC AX, R12 - CMOVQCC AX, R13 - CMOVQCC AX, R14 - CMOVQCC AX, R15 - ADDQ R12, R8 - ADCQ R13, R9 - ADCQ R14, R10 - ADCQ R15, R11 - MOVQ R8, 0(DX) - MOVQ R9, 8(DX) - MOVQ R10, 16(DX) - MOVQ R11, 24(DX) - - // reduce element(CX,BX,SI,DI) using temp registers (R8,R9,R10,R11) - REDUCE(CX,BX,SI,DI,R8,R9,R10,R11) - - MOVQ a+0(FP), AX - MOVQ CX, 0(AX) - MOVQ BX, 8(AX) - MOVQ SI, 16(AX) - MOVQ DI, 24(AX) - RET - -// addVec(res, a, b *Element, n uint64) res[0...n] = a[0...n] + b[0...n] -TEXT ·addVec(SB), NOSPLIT, $0-32 - MOVQ res+0(FP), CX - MOVQ a+8(FP), AX - MOVQ b+16(FP), DX - MOVQ n+24(FP), BX - -loop_1: - TESTQ BX, BX - JEQ done_2 // n == 0, we are done - - // a[0] -> SI - // a[1] -> DI - // a[2] -> R8 - // a[3] -> R9 - MOVQ 0(AX), SI - MOVQ 8(AX), DI - MOVQ 16(AX), R8 - MOVQ 24(AX), R9 - ADDQ 0(DX), SI - ADCQ 8(DX), DI - ADCQ 16(DX), R8 - ADCQ 24(DX), R9 - - // reduce element(SI,DI,R8,R9) using temp registers (R10,R11,R12,R13) - REDUCE(SI,DI,R8,R9,R10,R11,R12,R13) - - MOVQ SI, 0(CX) - MOVQ DI, 8(CX) - MOVQ R8, 16(CX) - MOVQ R9, 24(CX) - - // increment pointers to visit next element - ADDQ $32, AX - ADDQ $32, DX - ADDQ $32, CX - DECQ BX // decrement n - JMP loop_1 - -done_2: - RET - -// subVec(res, a, b *Element, n uint64) res[0...n] = a[0...n] - b[0...n] -TEXT ·subVec(SB), NOSPLIT, $0-32 - MOVQ res+0(FP), CX - MOVQ a+8(FP), AX - MOVQ b+16(FP), DX - MOVQ n+24(FP), BX - XORQ SI, SI - -loop_3: - TESTQ BX, BX - JEQ done_4 // n == 0, we are done - - // a[0] -> DI - // a[1] -> R8 - // a[2] -> R9 - // a[3] -> R10 - MOVQ 0(AX), DI - MOVQ 8(AX), R8 - MOVQ 16(AX), R9 - MOVQ 24(AX), R10 - SUBQ 0(DX), DI - SBBQ 8(DX), R8 - SBBQ 16(DX), R9 - SBBQ 24(DX), R10 - - // reduce (a-b) mod q - // q[0] -> R11 - // q[1] -> R12 - // q[2] -> R13 - // q[3] -> R14 - MOVQ $0x3c208c16d87cfd47, R11 - MOVQ $0x97816a916871ca8d, R12 - MOVQ $0xb85045b68181585d, R13 - MOVQ $0x30644e72e131a029, R14 - CMOVQCC SI, R11 - CMOVQCC SI, R12 - CMOVQCC SI, R13 - CMOVQCC SI, R14 - - // add registers (q or 0) to a, and set to result - ADDQ R11, DI - ADCQ R12, 
R8 - ADCQ R13, R9 - ADCQ R14, R10 - MOVQ DI, 0(CX) - MOVQ R8, 8(CX) - MOVQ R9, 16(CX) - MOVQ R10, 24(CX) - - // increment pointers to visit next element - ADDQ $32, AX - ADDQ $32, DX - ADDQ $32, CX - DECQ BX // decrement n - JMP loop_3 - -done_4: - RET - -// scalarMulVec(res, a, b *Element, n uint64) res[0...n] = a[0...n] * b -TEXT ·scalarMulVec(SB), $56-32 - CMPB ·supportAdx(SB), $1 - JNE noAdx_5 - MOVQ a+8(FP), R11 - MOVQ b+16(FP), R10 - MOVQ n+24(FP), R12 - - // scalar[0] -> SI - // scalar[1] -> DI - // scalar[2] -> R8 - // scalar[3] -> R9 - MOVQ 0(R10), SI - MOVQ 8(R10), DI - MOVQ 16(R10), R8 - MOVQ 24(R10), R9 - MOVQ res+0(FP), R10 - -loop_6: - TESTQ R12, R12 - JEQ done_7 // n == 0, we are done - - // TODO @gbotrel this is generated from the same macro as the unit mul, we should refactor this in a single asm function - // A -> BP - // t[0] -> R14 - // t[1] -> R15 - // t[2] -> CX - // t[3] -> BX - // clear the flags - XORQ AX, AX - MOVQ 0(R11), DX - - // (A,t[0]) := x[0]*y[0] + A - MULXQ SI, R14, R15 - - // (A,t[1]) := x[1]*y[0] + A - MULXQ DI, AX, CX - ADOXQ AX, R15 - - // (A,t[2]) := x[2]*y[0] + A - MULXQ R8, AX, BX - ADOXQ AX, CX - - // (A,t[3]) := x[3]*y[0] + A - MULXQ R9, AX, BP - ADOXQ AX, BX - - // A += carries from ADCXQ and ADOXQ - MOVQ $0, AX - ADOXQ AX, BP - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ R14, DX - - // clear the flags - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ q<>+0(SB), AX, R13 - ADCXQ R14, AX - MOVQ R13, R14 - - // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R15, R14 - MULXQ q<>+8(SB), AX, R15 - ADOXQ AX, R14 - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ CX, R15 - MULXQ q<>+16(SB), AX, CX - ADOXQ AX, R15 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ BX, CX - MULXQ q<>+24(SB), AX, BX - ADOXQ AX, CX - - // t[3] = C + A - MOVQ $0, AX - ADCXQ AX, BX - ADOXQ BP, BX - - // clear the flags - XORQ AX, AX - MOVQ 8(R11), DX - - // (A,t[0]) := t[0] + x[0]*y[1] + A - MULXQ SI, AX, BP - ADOXQ AX, R14 - - // (A,t[1]) := t[1] + x[1]*y[1] + A - ADCXQ BP, R15 - MULXQ DI, AX, BP - ADOXQ AX, R15 - - // (A,t[2]) := t[2] + x[2]*y[1] + A - ADCXQ BP, CX - MULXQ R8, AX, BP - ADOXQ AX, CX - - // (A,t[3]) := t[3] + x[3]*y[1] + A - ADCXQ BP, BX - MULXQ R9, AX, BP - ADOXQ AX, BX - - // A += carries from ADCXQ and ADOXQ - MOVQ $0, AX - ADCXQ AX, BP - ADOXQ AX, BP - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ R14, DX - - // clear the flags - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ q<>+0(SB), AX, R13 - ADCXQ R14, AX - MOVQ R13, R14 - - // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R15, R14 - MULXQ q<>+8(SB), AX, R15 - ADOXQ AX, R14 - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ CX, R15 - MULXQ q<>+16(SB), AX, CX - ADOXQ AX, R15 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ BX, CX - MULXQ q<>+24(SB), AX, BX - ADOXQ AX, CX - - // t[3] = C + A - MOVQ $0, AX - ADCXQ AX, BX - ADOXQ BP, BX - - // clear the flags - XORQ AX, AX - MOVQ 16(R11), DX - - // (A,t[0]) := t[0] + x[0]*y[2] + A - MULXQ SI, AX, BP - ADOXQ AX, R14 - - // (A,t[1]) := t[1] + x[1]*y[2] + A - ADCXQ BP, R15 - MULXQ DI, AX, BP - ADOXQ AX, R15 - - // (A,t[2]) := t[2] + x[2]*y[2] + A - ADCXQ BP, CX - MULXQ R8, AX, BP - ADOXQ AX, CX - - // (A,t[3]) := t[3] + x[3]*y[2] + A - ADCXQ BP, BX - MULXQ R9, AX, BP - ADOXQ AX, BX - - // A += carries from ADCXQ and ADOXQ - MOVQ $0, AX - ADCXQ AX, BP - ADOXQ AX, BP - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ R14, DX - - // clear the flags - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ q<>+0(SB), AX, R13 - ADCXQ R14, AX - MOVQ R13, R14 - - // (C,t[0]) 
:= t[1] + m*q[1] + C - ADCXQ R15, R14 - MULXQ q<>+8(SB), AX, R15 - ADOXQ AX, R14 - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ CX, R15 - MULXQ q<>+16(SB), AX, CX - ADOXQ AX, R15 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ BX, CX - MULXQ q<>+24(SB), AX, BX - ADOXQ AX, CX - - // t[3] = C + A - MOVQ $0, AX - ADCXQ AX, BX - ADOXQ BP, BX - - // clear the flags - XORQ AX, AX - MOVQ 24(R11), DX - - // (A,t[0]) := t[0] + x[0]*y[3] + A - MULXQ SI, AX, BP - ADOXQ AX, R14 - - // (A,t[1]) := t[1] + x[1]*y[3] + A - ADCXQ BP, R15 - MULXQ DI, AX, BP - ADOXQ AX, R15 - - // (A,t[2]) := t[2] + x[2]*y[3] + A - ADCXQ BP, CX - MULXQ R8, AX, BP - ADOXQ AX, CX - - // (A,t[3]) := t[3] + x[3]*y[3] + A - ADCXQ BP, BX - MULXQ R9, AX, BP - ADOXQ AX, BX - - // A += carries from ADCXQ and ADOXQ - MOVQ $0, AX - ADCXQ AX, BP - ADOXQ AX, BP - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ R14, DX - - // clear the flags - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ q<>+0(SB), AX, R13 - ADCXQ R14, AX - MOVQ R13, R14 - - // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R15, R14 - MULXQ q<>+8(SB), AX, R15 - ADOXQ AX, R14 - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ CX, R15 - MULXQ q<>+16(SB), AX, CX - ADOXQ AX, R15 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ BX, CX - MULXQ q<>+24(SB), AX, BX - ADOXQ AX, CX - - // t[3] = C + A - MOVQ $0, AX - ADCXQ AX, BX - ADOXQ BP, BX - - // reduce t mod q - // reduce element(R14,R15,CX,BX) using temp registers (R13,AX,DX,s0-8(SP)) - REDUCE(R14,R15,CX,BX,R13,AX,DX,s0-8(SP)) - - MOVQ R14, 0(R10) - MOVQ R15, 8(R10) - MOVQ CX, 16(R10) - MOVQ BX, 24(R10) - - // increment pointers to visit next element - ADDQ $32, R11 - ADDQ $32, R10 - DECQ R12 // decrement n - JMP loop_6 - -done_7: - RET - -noAdx_5: - MOVQ n+24(FP), DX - MOVQ res+0(FP), AX - MOVQ AX, (SP) - MOVQ DX, 8(SP) - MOVQ DX, 16(SP) - MOVQ a+8(FP), AX - MOVQ AX, 24(SP) - MOVQ DX, 32(SP) - MOVQ DX, 40(SP) - MOVQ b+16(FP), AX - MOVQ AX, 48(SP) - CALL ·scalarMulVecGeneric(SB) - RET - -// sumVec(res, a *Element, n uint64) res = sum(a[0...n]) -TEXT ·sumVec(SB), NOSPLIT, $0-24 - - // Derived from https://github.com/a16z/vectorized-fields - // The idea is to use Z registers to accumulate the sum of elements, 8 by 8 - // first, we handle the case where n % 8 != 0 - // then, we loop over the elements 8 by 8 and accumulate the sum in the Z registers - // finally, we reduce the sum and store it in res - // - // when we move an element of a into a Z register, we use VPMOVZXDQ - // let's note w0...w3 the 4 64bits words of ai: w0 = ai[0], w1 = ai[1], w2 = ai[2], w3 = ai[3] - // VPMOVZXDQ(ai, Z0) will result in - // Z0= [hi(w3), lo(w3), hi(w2), lo(w2), hi(w1), lo(w1), hi(w0), lo(w0)] - // with hi(wi) the high 32 bits of wi and lo(wi) the low 32 bits of wi - // we can safely add 2^32+1 times Z registers constructed this way without overflow - // since each of this lo/hi bits are moved into a "64bits" slot - // N = 2^64-1 / 2^32-1 = 2^32+1 - // - // we then propagate the carry using ADOXQ and ADCXQ - // r0 = w0l + lo(woh) - // r1 = carry + hi(woh) + w1l + lo(w1h) - // r2 = carry + hi(w1h) + w2l + lo(w2h) - // r3 = carry + hi(w2h) + w3l + lo(w3h) - // r4 = carry + hi(w3h) - // we then reduce the sum using a single-word Barrett reduction - // we pick mu = 2^288 / q; which correspond to 4.5 words max. - // meaning we must guarantee that r4 fits in 32bits. 
- // To do so, we reduce N to 2^32-1 (since r4 receives 2 carries max) - - MOVQ a+8(FP), R14 - MOVQ n+16(FP), R15 - - // initialize accumulators Z0, Z1, Z2, Z3, Z4, Z5, Z6, Z7 - VXORPS Z0, Z0, Z0 - VMOVDQA64 Z0, Z1 - VMOVDQA64 Z0, Z2 - VMOVDQA64 Z0, Z3 - VMOVDQA64 Z0, Z4 - VMOVDQA64 Z0, Z5 - VMOVDQA64 Z0, Z6 - VMOVDQA64 Z0, Z7 - - // n % 8 -> CX - // n / 8 -> R15 - MOVQ R15, CX - ANDQ $7, CX - SHRQ $3, R15 - -loop_single_10: - TESTQ CX, CX - JEQ loop8by8_8 // n % 8 == 0, we are going to loop over 8 by 8 - VPMOVZXDQ 0(R14), Z8 - VPADDQ Z8, Z0, Z0 - ADDQ $32, R14 - DECQ CX // decrement nMod8 - JMP loop_single_10 - -loop8by8_8: - TESTQ R15, R15 - JEQ accumulate_11 // n == 0, we are going to accumulate - VPMOVZXDQ 0*32(R14), Z8 - VPMOVZXDQ 1*32(R14), Z9 - VPMOVZXDQ 2*32(R14), Z10 - VPMOVZXDQ 3*32(R14), Z11 - VPMOVZXDQ 4*32(R14), Z12 - VPMOVZXDQ 5*32(R14), Z13 - VPMOVZXDQ 6*32(R14), Z14 - VPMOVZXDQ 7*32(R14), Z15 - PREFETCHT0 256(R14) - VPADDQ Z8, Z0, Z0 - VPADDQ Z9, Z1, Z1 - VPADDQ Z10, Z2, Z2 - VPADDQ Z11, Z3, Z3 - VPADDQ Z12, Z4, Z4 - VPADDQ Z13, Z5, Z5 - VPADDQ Z14, Z6, Z6 - VPADDQ Z15, Z7, Z7 - - // increment pointers to visit next 8 elements - ADDQ $256, R14 - DECQ R15 // decrement n - JMP loop8by8_8 - -accumulate_11: - // accumulate the 8 Z registers into Z0 - VPADDQ Z7, Z6, Z6 - VPADDQ Z6, Z5, Z5 - VPADDQ Z5, Z4, Z4 - VPADDQ Z4, Z3, Z3 - VPADDQ Z3, Z2, Z2 - VPADDQ Z2, Z1, Z1 - VPADDQ Z1, Z0, Z0 - - // carry propagation - // lo(w0) -> BX - // hi(w0) -> SI - // lo(w1) -> DI - // hi(w1) -> R8 - // lo(w2) -> R9 - // hi(w2) -> R10 - // lo(w3) -> R11 - // hi(w3) -> R12 - VMOVQ X0, BX - VALIGNQ $1, Z0, Z0, Z0 - VMOVQ X0, SI - VALIGNQ $1, Z0, Z0, Z0 - VMOVQ X0, DI - VALIGNQ $1, Z0, Z0, Z0 - VMOVQ X0, R8 - VALIGNQ $1, Z0, Z0, Z0 - VMOVQ X0, R9 - VALIGNQ $1, Z0, Z0, Z0 - VMOVQ X0, R10 - VALIGNQ $1, Z0, Z0, Z0 - VMOVQ X0, R11 - VALIGNQ $1, Z0, Z0, Z0 - VMOVQ X0, R12 - - // lo(hi(wo)) -> R13 - // lo(hi(w1)) -> CX - // lo(hi(w2)) -> R15 - // lo(hi(w3)) -> R14 -#define SPLIT_LO_HI(lo, hi) \ - MOVQ hi, lo; \ - ANDQ $0xffffffff, lo; \ - SHLQ $32, lo; \ - SHRQ $32, hi; \ - - SPLIT_LO_HI(R13, SI) - SPLIT_LO_HI(CX, R8) - SPLIT_LO_HI(R15, R10) - SPLIT_LO_HI(R14, R12) - - // r0 = w0l + lo(woh) - // r1 = carry + hi(woh) + w1l + lo(w1h) - // r2 = carry + hi(w1h) + w2l + lo(w2h) - // r3 = carry + hi(w2h) + w3l + lo(w3h) - // r4 = carry + hi(w3h) - - XORQ AX, AX // clear the flags - ADOXQ R13, BX - ADOXQ CX, DI - ADCXQ SI, DI - ADOXQ R15, R9 - ADCXQ R8, R9 - ADOXQ R14, R11 - ADCXQ R10, R11 - ADOXQ AX, R12 - ADCXQ AX, R12 - - // r[0] -> BX - // r[1] -> DI - // r[2] -> R9 - // r[3] -> R11 - // r[4] -> R12 - // reduce using single-word Barrett - // mu=2^288 / q -> SI - MOVQ mu<>(SB), SI - MOVQ R11, AX - SHRQ $32, R12, AX - MULQ SI // high bits of res stored in DX - MULXQ q<>+0(SB), AX, SI - SUBQ AX, BX - SBBQ SI, DI - MULXQ q<>+16(SB), AX, SI - SBBQ AX, R9 - SBBQ SI, R11 - SBBQ $0, R12 - MULXQ q<>+8(SB), AX, SI - SUBQ AX, DI - SBBQ SI, R9 - MULXQ q<>+24(SB), AX, SI - SBBQ AX, R11 - SBBQ SI, R12 - MOVQ BX, R8 - MOVQ DI, R10 - MOVQ R9, R13 - MOVQ R11, CX - SUBQ q<>+0(SB), BX - SBBQ q<>+8(SB), DI - SBBQ q<>+16(SB), R9 - SBBQ q<>+24(SB), R11 - SBBQ $0, R12 - JCS modReduced_12 - MOVQ BX, R8 - MOVQ DI, R10 - MOVQ R9, R13 - MOVQ R11, CX - SUBQ q<>+0(SB), BX - SBBQ q<>+8(SB), DI - SBBQ q<>+16(SB), R9 - SBBQ q<>+24(SB), R11 - SBBQ $0, R12 - JCS modReduced_12 - MOVQ BX, R8 - MOVQ DI, R10 - MOVQ R9, R13 - MOVQ R11, CX - -modReduced_12: - MOVQ res+0(FP), SI - MOVQ R8, 0(SI) - MOVQ R10, 8(SI) - MOVQ R13, 16(SI) - MOVQ CX, 
24(SI) - -done_9: - RET diff --git a/ecc/bn254/fp/element_ops_purego.go b/ecc/bn254/fp/element_ops_purego.go index 8724acf1b..6880de49a 100644 --- a/ecc/bn254/fp/element_ops_purego.go +++ b/ecc/bn254/fp/element_ops_purego.go @@ -89,48 +89,8 @@ func (vector *Vector) Sum() (res Element) { // x and y must be less than q func (z *Element) Mul(x, y *Element) *Element { - // Implements CIOS multiplication -- section 2.3.2 of Tolga Acar's thesis - // https://www.microsoft.com/en-us/research/wp-content/uploads/1998/06/97Acar.pdf - // - // The algorithm: - // - // for i=0 to N-1 - // C := 0 - // for j=0 to N-1 - // (C,t[j]) := t[j] + x[j]*y[i] + C - // (t[N+1],t[N]) := t[N] + C - // - // C := 0 - // m := t[0]*q'[0] mod D - // (C,_) := t[0] + m*q[0] - // for j=1 to N-1 - // (C,t[j-1]) := t[j] + m*q[j] + C - // - // (C,t[N-1]) := t[N] + C - // t[N] := t[N+1] + C - // - // → N is the number of machine words needed to store the modulus q - // → D is the word size. For example, on a 64-bit architecture D is 2 64 - // → x[i], y[i], q[i] is the ith word of the numbers x,y,q - // → q'[0] is the lowest word of the number -q⁻¹ mod r. This quantity is pre-computed, as it does not depend on the inputs. - // → t is a temporary array of size N+2 - // → C, S are machine words. A pair (C,S) refers to (hi-bits, lo-bits) of a two-word number - // - // As described here https://hackmd.io/@gnark/modular_multiplication we can get rid of one carry chain and simplify: - // (also described in https://eprint.iacr.org/2022/1400.pdf annex) - // - // for i=0 to N-1 - // (A,t[0]) := t[0] + x[0]*y[i] - // m := t[0]*q'[0] mod W - // C,_ := t[0] + m*q[0] - // for j=1 to N-1 - // (A,t[j]) := t[j] + x[j]*y[i] + A - // (C,t[j-1]) := t[j] + m*q[j] + C - // - // t[N-1] = C + A - // - // This optimization saves 5N + 2 additions in the algorithm, and can be used whenever the highest bit - // of the modulus is zero (and not all of the remaining bits are set). + // Algorithm 2 of "Faster Montgomery Multiplication and Multi-Scalar-Multiplication for SNARKS" + // by Y. El Housni and G. Botrel https://doi.org/10.46586/tches.v2023.i3.504-521 var t0, t1, t2, t3 uint64 var u0, u1, u2, u3 uint64 diff --git a/ecc/bn254/fp/element_test.go b/ecc/bn254/fp/element_test.go index 8c314b633..a7cf428f4 100644 --- a/ecc/bn254/fp/element_test.go +++ b/ecc/bn254/fp/element_test.go @@ -637,7 +637,6 @@ func TestElementBitLen(t *testing.T) { )) properties.TestingRun(t, gopter.ConsoleReporter(false)) - } func TestElementButterflies(t *testing.T) { @@ -719,11 +718,17 @@ func TestElementVecOps(t *testing.T) { // set m to max values element // it's not really q-1 (since we have montgomery representation) // but it's the "largest" legal value - qMinus1 := new(big.Int).Sub(Modulus(), big.NewInt(1)) - - var eQMinus1 Element - for i, v := range qMinus1.Bits() { - eQMinus1[i] = uint64(v) + eQMinus1 := qElement + if eQMinus1[0] != 0 { + eQMinus1[0]-- + } else { + eQMinus1[0] = ^uint64(0) + for i := 1; i < len(eQMinus1); i++ { + if eQMinus1[i] != 0 { + eQMinus1[i]-- + break + } + } } for i := 0; i < N; i++ { diff --git a/ecc/bn254/fr/element.go b/ecc/bn254/fr/element.go index cda0b2c28..f7cfbf64d 100644 --- a/ecc/bn254/fr/element.go +++ b/ecc/bn254/fr/element.go @@ -477,32 +477,8 @@ func (z *Element) Select(c int, x0 *Element, x1 *Element) *Element { // and is used for testing purposes. 
func _mulGeneric(z, x, y *Element) { - // Implements CIOS multiplication -- section 2.3.2 of Tolga Acar's thesis - // https://www.microsoft.com/en-us/research/wp-content/uploads/1998/06/97Acar.pdf - // - // The algorithm: - // - // for i=0 to N-1 - // C := 0 - // for j=0 to N-1 - // (C,t[j]) := t[j] + x[j]*y[i] + C - // (t[N+1],t[N]) := t[N] + C - // - // C := 0 - // m := t[0]*q'[0] mod D - // (C,_) := t[0] + m*q[0] - // for j=1 to N-1 - // (C,t[j-1]) := t[j] + m*q[j] + C - // - // (C,t[N-1]) := t[N] + C - // t[N] := t[N+1] + C - // - // → N is the number of machine words needed to store the modulus q - // → D is the word size. For example, on a 64-bit architecture D is 2 64 - // → x[i], y[i], q[i] is the ith word of the numbers x,y,q - // → q'[0] is the lowest word of the number -q⁻¹ mod r. This quantity is pre-computed, as it does not depend on the inputs. - // → t is a temporary array of size N+2 - // → C, S are machine words. A pair (C,S) refers to (hi-bits, lo-bits) of a two-word number + // Algorithm 2 of "Faster Montgomery Multiplication and Multi-Scalar-Multiplication for SNARKS" + // by Y. El Housni and G. Botrel https://doi.org/10.46586/tches.v2023.i3.504-521 var t [5]uint64 var D uint64 diff --git a/ecc/bn254/fr/element_mul_amd64.s b/ecc/bn254/fr/element_mul_amd64.s deleted file mode 100644 index 98e98ef6b..000000000 --- a/ecc/bn254/fr/element_mul_amd64.s +++ /dev/null @@ -1,490 +0,0 @@ -// +build !purego - -// Copyright 2020 ConsenSys Software Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
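The deleted element_mul_amd64.s files also carried fromMont, whose comment spells out the y = 1 specialisation of the same Montgomery loop: one reduction round per limb, no multiplications by y. A minimal Go rendering follows, reusing the madd helper from the multiplication sketch earlier; fromMontSketch and the explicit q/qInv0 parameters are illustrative, and the same top-bit-of-q-is-zero assumption applies.

// fromMontSketch computes x * 2^-256 mod q, i.e. converts x out of
// Montgomery form.
func fromMontSketch(x, q [4]uint64, qInv0 uint64) [4]uint64 {
	t := x
	for i := 0; i < 4; i++ {
		// m := t[0]*q'[0] mod 2^64
		m := t[0] * qInv0
		// C,_ := t[0] + m*q[0]   (the low word cancels)
		C, _ := madd(m, q[0], t[0], 0)
		for j := 1; j < 4; j++ {
			// (C,t[j-1]) := t[j] + m*q[j] + C
			C, t[j-1] = madd(m, q[j], t[j], C)
		}
		// t[N-1] = C
		t[3] = C
	}
	// single conditional subtraction, as in the REDUCE macro
	var z [4]uint64
	var b uint64
	z[0], b = bits.Sub64(t[0], q[0], 0)
	z[1], b = bits.Sub64(t[1], q[1], b)
	z[2], b = bits.Sub64(t[2], q[2], b)
	z[3], b = bits.Sub64(t[3], q[3], b)
	if b != 0 { // already < q
		return t
	}
	return z
}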
- -#include "textflag.h" -#include "funcdata.h" - -// modulus q -DATA q<>+0(SB)/8, $0x43e1f593f0000001 -DATA q<>+8(SB)/8, $0x2833e84879b97091 -DATA q<>+16(SB)/8, $0xb85045b68181585d -DATA q<>+24(SB)/8, $0x30644e72e131a029 -GLOBL q<>(SB), (RODATA+NOPTR), $32 - -// qInv0 q'[0] -DATA qInv0<>(SB)/8, $0xc2e1f593efffffff -GLOBL qInv0<>(SB), (RODATA+NOPTR), $8 -// Mu -DATA mu<>(SB)/8, $0x000000054a474626 -GLOBL mu<>(SB), (RODATA+NOPTR), $8 - -#define REDUCE(ra0, ra1, ra2, ra3, rb0, rb1, rb2, rb3) \ - MOVQ ra0, rb0; \ - SUBQ q<>(SB), ra0; \ - MOVQ ra1, rb1; \ - SBBQ q<>+8(SB), ra1; \ - MOVQ ra2, rb2; \ - SBBQ q<>+16(SB), ra2; \ - MOVQ ra3, rb3; \ - SBBQ q<>+24(SB), ra3; \ - CMOVQCS rb0, ra0; \ - CMOVQCS rb1, ra1; \ - CMOVQCS rb2, ra2; \ - CMOVQCS rb3, ra3; \ - -// mul(res, x, y *Element) -TEXT ·mul(SB), $24-24 - - // the algorithm is described in the Element.Mul declaration (.go) - // however, to benefit from the ADCX and ADOX carry chains - // we split the inner loops in 2: - // for i=0 to N-1 - // for j=0 to N-1 - // (A,t[j]) := t[j] + x[j]*y[i] + A - // m := t[0]*q'[0] mod W - // C,_ := t[0] + m*q[0] - // for j=1 to N-1 - // (C,t[j-1]) := t[j] + m*q[j] + C - // t[N-1] = C + A - - NO_LOCAL_POINTERS - CMPB ·supportAdx(SB), $1 - JNE noAdx_1 - MOVQ x+8(FP), SI - - // x[0] -> DI - // x[1] -> R8 - // x[2] -> R9 - // x[3] -> R10 - MOVQ 0(SI), DI - MOVQ 8(SI), R8 - MOVQ 16(SI), R9 - MOVQ 24(SI), R10 - MOVQ y+16(FP), R11 - - // A -> BP - // t[0] -> R14 - // t[1] -> R13 - // t[2] -> CX - // t[3] -> BX - // clear the flags - XORQ AX, AX - MOVQ 0(R11), DX - - // (A,t[0]) := x[0]*y[0] + A - MULXQ DI, R14, R13 - - // (A,t[1]) := x[1]*y[0] + A - MULXQ R8, AX, CX - ADOXQ AX, R13 - - // (A,t[2]) := x[2]*y[0] + A - MULXQ R9, AX, BX - ADOXQ AX, CX - - // (A,t[3]) := x[3]*y[0] + A - MULXQ R10, AX, BP - ADOXQ AX, BX - - // A += carries from ADCXQ and ADOXQ - MOVQ $0, AX - ADOXQ AX, BP - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ R14, DX - - // clear the flags - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ q<>+0(SB), AX, R12 - ADCXQ R14, AX - MOVQ R12, R14 - - // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R13, R14 - MULXQ q<>+8(SB), AX, R13 - ADOXQ AX, R14 - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ CX, R13 - MULXQ q<>+16(SB), AX, CX - ADOXQ AX, R13 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ BX, CX - MULXQ q<>+24(SB), AX, BX - ADOXQ AX, CX - - // t[3] = C + A - MOVQ $0, AX - ADCXQ AX, BX - ADOXQ BP, BX - - // clear the flags - XORQ AX, AX - MOVQ 8(R11), DX - - // (A,t[0]) := t[0] + x[0]*y[1] + A - MULXQ DI, AX, BP - ADOXQ AX, R14 - - // (A,t[1]) := t[1] + x[1]*y[1] + A - ADCXQ BP, R13 - MULXQ R8, AX, BP - ADOXQ AX, R13 - - // (A,t[2]) := t[2] + x[2]*y[1] + A - ADCXQ BP, CX - MULXQ R9, AX, BP - ADOXQ AX, CX - - // (A,t[3]) := t[3] + x[3]*y[1] + A - ADCXQ BP, BX - MULXQ R10, AX, BP - ADOXQ AX, BX - - // A += carries from ADCXQ and ADOXQ - MOVQ $0, AX - ADCXQ AX, BP - ADOXQ AX, BP - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ R14, DX - - // clear the flags - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ q<>+0(SB), AX, R12 - ADCXQ R14, AX - MOVQ R12, R14 - - // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R13, R14 - MULXQ q<>+8(SB), AX, R13 - ADOXQ AX, R14 - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ CX, R13 - MULXQ q<>+16(SB), AX, CX - ADOXQ AX, R13 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ BX, CX - MULXQ q<>+24(SB), AX, BX - ADOXQ AX, CX - - // t[3] = C + A - MOVQ $0, AX - ADCXQ AX, BX - ADOXQ BP, BX - - // clear the flags - XORQ AX, AX - MOVQ 16(R11), DX - - // (A,t[0]) := t[0] + x[0]*y[2] 
+ A - MULXQ DI, AX, BP - ADOXQ AX, R14 - - // (A,t[1]) := t[1] + x[1]*y[2] + A - ADCXQ BP, R13 - MULXQ R8, AX, BP - ADOXQ AX, R13 - - // (A,t[2]) := t[2] + x[2]*y[2] + A - ADCXQ BP, CX - MULXQ R9, AX, BP - ADOXQ AX, CX - - // (A,t[3]) := t[3] + x[3]*y[2] + A - ADCXQ BP, BX - MULXQ R10, AX, BP - ADOXQ AX, BX - - // A += carries from ADCXQ and ADOXQ - MOVQ $0, AX - ADCXQ AX, BP - ADOXQ AX, BP - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ R14, DX - - // clear the flags - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ q<>+0(SB), AX, R12 - ADCXQ R14, AX - MOVQ R12, R14 - - // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R13, R14 - MULXQ q<>+8(SB), AX, R13 - ADOXQ AX, R14 - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ CX, R13 - MULXQ q<>+16(SB), AX, CX - ADOXQ AX, R13 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ BX, CX - MULXQ q<>+24(SB), AX, BX - ADOXQ AX, CX - - // t[3] = C + A - MOVQ $0, AX - ADCXQ AX, BX - ADOXQ BP, BX - - // clear the flags - XORQ AX, AX - MOVQ 24(R11), DX - - // (A,t[0]) := t[0] + x[0]*y[3] + A - MULXQ DI, AX, BP - ADOXQ AX, R14 - - // (A,t[1]) := t[1] + x[1]*y[3] + A - ADCXQ BP, R13 - MULXQ R8, AX, BP - ADOXQ AX, R13 - - // (A,t[2]) := t[2] + x[2]*y[3] + A - ADCXQ BP, CX - MULXQ R9, AX, BP - ADOXQ AX, CX - - // (A,t[3]) := t[3] + x[3]*y[3] + A - ADCXQ BP, BX - MULXQ R10, AX, BP - ADOXQ AX, BX - - // A += carries from ADCXQ and ADOXQ - MOVQ $0, AX - ADCXQ AX, BP - ADOXQ AX, BP - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ R14, DX - - // clear the flags - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ q<>+0(SB), AX, R12 - ADCXQ R14, AX - MOVQ R12, R14 - - // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R13, R14 - MULXQ q<>+8(SB), AX, R13 - ADOXQ AX, R14 - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ CX, R13 - MULXQ q<>+16(SB), AX, CX - ADOXQ AX, R13 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ BX, CX - MULXQ q<>+24(SB), AX, BX - ADOXQ AX, CX - - // t[3] = C + A - MOVQ $0, AX - ADCXQ AX, BX - ADOXQ BP, BX - - // reduce element(R14,R13,CX,BX) using temp registers (SI,R12,R11,DI) - REDUCE(R14,R13,CX,BX,SI,R12,R11,DI) - - MOVQ res+0(FP), AX - MOVQ R14, 0(AX) - MOVQ R13, 8(AX) - MOVQ CX, 16(AX) - MOVQ BX, 24(AX) - RET - -noAdx_1: - MOVQ res+0(FP), AX - MOVQ AX, (SP) - MOVQ x+8(FP), AX - MOVQ AX, 8(SP) - MOVQ y+16(FP), AX - MOVQ AX, 16(SP) - CALL ·_mulGeneric(SB) - RET - -TEXT ·fromMont(SB), $8-8 - NO_LOCAL_POINTERS - - // the algorithm is described here - // https://hackmd.io/@gnark/modular_multiplication - // when y = 1 we have: - // for i=0 to N-1 - // t[i] = x[i] - // for i=0 to N-1 - // m := t[0]*q'[0] mod W - // C,_ := t[0] + m*q[0] - // for j=1 to N-1 - // (C,t[j-1]) := t[j] + m*q[j] + C - // t[N-1] = C - CMPB ·supportAdx(SB), $1 - JNE noAdx_2 - MOVQ res+0(FP), DX - MOVQ 0(DX), R14 - MOVQ 8(DX), R13 - MOVQ 16(DX), CX - MOVQ 24(DX), BX - XORQ DX, DX - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ R14, DX - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ q<>+0(SB), AX, BP - ADCXQ R14, AX - MOVQ BP, R14 - - // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R13, R14 - MULXQ q<>+8(SB), AX, R13 - ADOXQ AX, R14 - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ CX, R13 - MULXQ q<>+16(SB), AX, CX - ADOXQ AX, R13 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ BX, CX - MULXQ q<>+24(SB), AX, BX - ADOXQ AX, CX - MOVQ $0, AX - ADCXQ AX, BX - ADOXQ AX, BX - XORQ DX, DX - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ R14, DX - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ q<>+0(SB), AX, BP - ADCXQ R14, AX - MOVQ BP, R14 - - // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ 
R13, R14 - MULXQ q<>+8(SB), AX, R13 - ADOXQ AX, R14 - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ CX, R13 - MULXQ q<>+16(SB), AX, CX - ADOXQ AX, R13 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ BX, CX - MULXQ q<>+24(SB), AX, BX - ADOXQ AX, CX - MOVQ $0, AX - ADCXQ AX, BX - ADOXQ AX, BX - XORQ DX, DX - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ R14, DX - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ q<>+0(SB), AX, BP - ADCXQ R14, AX - MOVQ BP, R14 - - // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R13, R14 - MULXQ q<>+8(SB), AX, R13 - ADOXQ AX, R14 - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ CX, R13 - MULXQ q<>+16(SB), AX, CX - ADOXQ AX, R13 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ BX, CX - MULXQ q<>+24(SB), AX, BX - ADOXQ AX, CX - MOVQ $0, AX - ADCXQ AX, BX - ADOXQ AX, BX - XORQ DX, DX - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ R14, DX - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ q<>+0(SB), AX, BP - ADCXQ R14, AX - MOVQ BP, R14 - - // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R13, R14 - MULXQ q<>+8(SB), AX, R13 - ADOXQ AX, R14 - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ CX, R13 - MULXQ q<>+16(SB), AX, CX - ADOXQ AX, R13 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ BX, CX - MULXQ q<>+24(SB), AX, BX - ADOXQ AX, CX - MOVQ $0, AX - ADCXQ AX, BX - ADOXQ AX, BX - - // reduce element(R14,R13,CX,BX) using temp registers (SI,DI,R8,R9) - REDUCE(R14,R13,CX,BX,SI,DI,R8,R9) - - MOVQ res+0(FP), AX - MOVQ R14, 0(AX) - MOVQ R13, 8(AX) - MOVQ CX, 16(AX) - MOVQ BX, 24(AX) - RET - -noAdx_2: - MOVQ res+0(FP), AX - MOVQ AX, (SP) - CALL ·_fromMontGeneric(SB) - RET diff --git a/ecc/bn254/fr/element_ops_amd64.go b/ecc/bn254/fr/element_ops_amd64.go index 49da23450..d21dabdaa 100644 --- a/ecc/bn254/fr/element_ops_amd64.go +++ b/ecc/bn254/fr/element_ops_amd64.go @@ -103,48 +103,8 @@ func sumVec(res *Element, a *Element, n uint64) // x and y must be less than q func (z *Element) Mul(x, y *Element) *Element { - // Implements CIOS multiplication -- section 2.3.2 of Tolga Acar's thesis - // https://www.microsoft.com/en-us/research/wp-content/uploads/1998/06/97Acar.pdf - // - // The algorithm: - // - // for i=0 to N-1 - // C := 0 - // for j=0 to N-1 - // (C,t[j]) := t[j] + x[j]*y[i] + C - // (t[N+1],t[N]) := t[N] + C - // - // C := 0 - // m := t[0]*q'[0] mod D - // (C,_) := t[0] + m*q[0] - // for j=1 to N-1 - // (C,t[j-1]) := t[j] + m*q[j] + C - // - // (C,t[N-1]) := t[N] + C - // t[N] := t[N+1] + C - // - // → N is the number of machine words needed to store the modulus q - // → D is the word size. For example, on a 64-bit architecture D is 2 64 - // → x[i], y[i], q[i] is the ith word of the numbers x,y,q - // → q'[0] is the lowest word of the number -q⁻¹ mod r. This quantity is pre-computed, as it does not depend on the inputs. - // → t is a temporary array of size N+2 - // → C, S are machine words. A pair (C,S) refers to (hi-bits, lo-bits) of a two-word number - // - // As described here https://hackmd.io/@gnark/modular_multiplication we can get rid of one carry chain and simplify: - // (also described in https://eprint.iacr.org/2022/1400.pdf annex) - // - // for i=0 to N-1 - // (A,t[0]) := t[0] + x[0]*y[i] - // m := t[0]*q'[0] mod W - // C,_ := t[0] + m*q[0] - // for j=1 to N-1 - // (A,t[j]) := t[j] + x[j]*y[i] + A - // (C,t[j-1]) := t[j] + m*q[j] + C - // - // t[N-1] = C + A - // - // This optimization saves 5N + 2 additions in the algorithm, and can be used whenever the highest bit - // of the modulus is zero (and not all of the remaining bits are set). 
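
This hunk swaps the long inline CIOS / no-carry walkthrough (removed above) for the one-line citation added just below. For readers without the paper at hand, a minimal pure-Go sketch of that 4-limb "no-carry" Montgomery product — an illustration built on math/bits under the stated assumptions, not the generated gnark-crypto code — could read:

package sketch

import "math/bits"

// montMul4 is a hypothetical 4-limb (256-bit) Montgomery product following the
// "no-carry" CIOS variant described in the comment removed above: q holds the
// modulus limbs (little-endian) with its top bit clear, qInv0 = -q[0]^-1 mod 2^64,
// and x, y are already in Montgomery form and less than q.
func montMul4(x, y, q [4]uint64, qInv0 uint64) [4]uint64 {
	var t [4]uint64
	for i := 0; i < 4; i++ {
		// (A,t[0]) := t[0] + x[0]*y[i]
		hi, lo := bits.Mul64(x[0], y[i])
		var c uint64
		t[0], c = bits.Add64(t[0], lo, 0)
		A := hi + c

		// m := t[0]*q'[0] mod W ; C,_ := t[0] + m*q[0]
		m := t[0] * qInv0
		hi, lo = bits.Mul64(m, q[0])
		_, c = bits.Add64(t[0], lo, 0)
		C := hi + c

		for j := 1; j < 4; j++ {
			// (A,t[j]) := t[j] + x[j]*y[i] + A
			hi, lo = bits.Mul64(x[j], y[i])
			lo, c = bits.Add64(lo, A, 0)
			hi += c
			t[j], c = bits.Add64(t[j], lo, 0)
			A = hi + c

			// (C,t[j-1]) := t[j] + m*q[j] + C (uses the freshly updated t[j])
			hi, lo = bits.Mul64(m, q[j])
			lo, c = bits.Add64(lo, C, 0)
			hi += c
			t[j-1], c = bits.Add64(t[j], lo, 0)
			C = hi + c
		}

		// t[3] = C + A: no carry out, thanks to the modulus condition above
		t[3] = C + A
	}

	// final conditional subtraction, the scalar equivalent of the REDUCE macro:
	// z = t - q; keep t if the subtraction borrowed (i.e. t < q already).
	var z [4]uint64
	var b uint64
	z[0], b = bits.Sub64(t[0], q[0], 0)
	z[1], b = bits.Sub64(t[1], q[1], b)
	z[2], b = bits.Sub64(t[2], q[2], b)
	z[3], b = bits.Sub64(t[3], q[3], b)
	if b != 0 {
		return t
	}
	return z
}

The assembly in this file interleaves the A and C carry chains with ADOXQ/ADCXQ so both progress per MULXQ; the sketch runs them sequentially for clarity.
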
+ // Algorithm 2 of "Faster Montgomery Multiplication and Multi-Scalar-Multiplication for SNARKS" + // by Y. El Housni and G. Botrel https://doi.org/10.46586/tches.v2023.i3.504-521 mul(z, x, y) return z diff --git a/ecc/bn254/fr/element_ops_amd64.s b/ecc/bn254/fr/element_ops_amd64.s index 811fa1397..0c996ad32 100644 --- a/ecc/bn254/fr/element_ops_amd64.s +++ b/ecc/bn254/fr/element_ops_amd64.s @@ -1,21 +1,11 @@ // +build !purego -// Copyright 2020 ConsenSys Software Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. +#define q0 $0x43e1f593f0000001 +#define q1 $0x2833e84879b97091 +#define q2 $0xb85045b68181585d +#define q3 $0x30644e72e131a029 -#include "textflag.h" -#include "funcdata.h" +#include "../../../field/asm/element_4w_amd64.h" // modulus q DATA q<>+0(SB)/8, $0x43e1f593f0000001 @@ -31,809 +21,3 @@ GLOBL qInv0<>(SB), (RODATA+NOPTR), $8 DATA mu<>(SB)/8, $0x000000054a474626 GLOBL mu<>(SB), (RODATA+NOPTR), $8 -#define REDUCE(ra0, ra1, ra2, ra3, rb0, rb1, rb2, rb3) \ - MOVQ ra0, rb0; \ - SUBQ q<>(SB), ra0; \ - MOVQ ra1, rb1; \ - SBBQ q<>+8(SB), ra1; \ - MOVQ ra2, rb2; \ - SBBQ q<>+16(SB), ra2; \ - MOVQ ra3, rb3; \ - SBBQ q<>+24(SB), ra3; \ - CMOVQCS rb0, ra0; \ - CMOVQCS rb1, ra1; \ - CMOVQCS rb2, ra2; \ - CMOVQCS rb3, ra3; \ - -TEXT ·reduce(SB), NOSPLIT, $0-8 - MOVQ res+0(FP), AX - MOVQ 0(AX), DX - MOVQ 8(AX), CX - MOVQ 16(AX), BX - MOVQ 24(AX), SI - - // reduce element(DX,CX,BX,SI) using temp registers (DI,R8,R9,R10) - REDUCE(DX,CX,BX,SI,DI,R8,R9,R10) - - MOVQ DX, 0(AX) - MOVQ CX, 8(AX) - MOVQ BX, 16(AX) - MOVQ SI, 24(AX) - RET - -// MulBy3(x *Element) -TEXT ·MulBy3(SB), NOSPLIT, $0-8 - MOVQ x+0(FP), AX - MOVQ 0(AX), DX - MOVQ 8(AX), CX - MOVQ 16(AX), BX - MOVQ 24(AX), SI - ADDQ DX, DX - ADCQ CX, CX - ADCQ BX, BX - ADCQ SI, SI - - // reduce element(DX,CX,BX,SI) using temp registers (DI,R8,R9,R10) - REDUCE(DX,CX,BX,SI,DI,R8,R9,R10) - - ADDQ 0(AX), DX - ADCQ 8(AX), CX - ADCQ 16(AX), BX - ADCQ 24(AX), SI - - // reduce element(DX,CX,BX,SI) using temp registers (R11,R12,R13,R14) - REDUCE(DX,CX,BX,SI,R11,R12,R13,R14) - - MOVQ DX, 0(AX) - MOVQ CX, 8(AX) - MOVQ BX, 16(AX) - MOVQ SI, 24(AX) - RET - -// MulBy5(x *Element) -TEXT ·MulBy5(SB), NOSPLIT, $0-8 - MOVQ x+0(FP), AX - MOVQ 0(AX), DX - MOVQ 8(AX), CX - MOVQ 16(AX), BX - MOVQ 24(AX), SI - ADDQ DX, DX - ADCQ CX, CX - ADCQ BX, BX - ADCQ SI, SI - - // reduce element(DX,CX,BX,SI) using temp registers (DI,R8,R9,R10) - REDUCE(DX,CX,BX,SI,DI,R8,R9,R10) - - ADDQ DX, DX - ADCQ CX, CX - ADCQ BX, BX - ADCQ SI, SI - - // reduce element(DX,CX,BX,SI) using temp registers (R11,R12,R13,R14) - REDUCE(DX,CX,BX,SI,R11,R12,R13,R14) - - ADDQ 0(AX), DX - ADCQ 8(AX), CX - ADCQ 16(AX), BX - ADCQ 24(AX), SI - - // reduce element(DX,CX,BX,SI) using temp registers (R15,DI,R8,R9) - REDUCE(DX,CX,BX,SI,R15,DI,R8,R9) - - MOVQ DX, 0(AX) - MOVQ CX, 8(AX) - MOVQ BX, 16(AX) - MOVQ SI, 24(AX) - RET - -// MulBy13(x *Element) -TEXT ·MulBy13(SB), NOSPLIT, $0-8 - MOVQ x+0(FP), AX - MOVQ 0(AX), DX - MOVQ 8(AX), CX - MOVQ 16(AX), BX - MOVQ 24(AX), SI - ADDQ DX, DX - ADCQ 
CX, CX - ADCQ BX, BX - ADCQ SI, SI - - // reduce element(DX,CX,BX,SI) using temp registers (DI,R8,R9,R10) - REDUCE(DX,CX,BX,SI,DI,R8,R9,R10) - - ADDQ DX, DX - ADCQ CX, CX - ADCQ BX, BX - ADCQ SI, SI - - // reduce element(DX,CX,BX,SI) using temp registers (R11,R12,R13,R14) - REDUCE(DX,CX,BX,SI,R11,R12,R13,R14) - - MOVQ DX, R11 - MOVQ CX, R12 - MOVQ BX, R13 - MOVQ SI, R14 - ADDQ DX, DX - ADCQ CX, CX - ADCQ BX, BX - ADCQ SI, SI - - // reduce element(DX,CX,BX,SI) using temp registers (DI,R8,R9,R10) - REDUCE(DX,CX,BX,SI,DI,R8,R9,R10) - - ADDQ R11, DX - ADCQ R12, CX - ADCQ R13, BX - ADCQ R14, SI - - // reduce element(DX,CX,BX,SI) using temp registers (DI,R8,R9,R10) - REDUCE(DX,CX,BX,SI,DI,R8,R9,R10) - - ADDQ 0(AX), DX - ADCQ 8(AX), CX - ADCQ 16(AX), BX - ADCQ 24(AX), SI - - // reduce element(DX,CX,BX,SI) using temp registers (DI,R8,R9,R10) - REDUCE(DX,CX,BX,SI,DI,R8,R9,R10) - - MOVQ DX, 0(AX) - MOVQ CX, 8(AX) - MOVQ BX, 16(AX) - MOVQ SI, 24(AX) - RET - -// Butterfly(a, b *Element) sets a = a + b; b = a - b -TEXT ·Butterfly(SB), NOSPLIT, $0-16 - MOVQ a+0(FP), AX - MOVQ 0(AX), CX - MOVQ 8(AX), BX - MOVQ 16(AX), SI - MOVQ 24(AX), DI - MOVQ CX, R8 - MOVQ BX, R9 - MOVQ SI, R10 - MOVQ DI, R11 - XORQ AX, AX - MOVQ b+8(FP), DX - ADDQ 0(DX), CX - ADCQ 8(DX), BX - ADCQ 16(DX), SI - ADCQ 24(DX), DI - SUBQ 0(DX), R8 - SBBQ 8(DX), R9 - SBBQ 16(DX), R10 - SBBQ 24(DX), R11 - MOVQ $0x43e1f593f0000001, R12 - MOVQ $0x2833e84879b97091, R13 - MOVQ $0xb85045b68181585d, R14 - MOVQ $0x30644e72e131a029, R15 - CMOVQCC AX, R12 - CMOVQCC AX, R13 - CMOVQCC AX, R14 - CMOVQCC AX, R15 - ADDQ R12, R8 - ADCQ R13, R9 - ADCQ R14, R10 - ADCQ R15, R11 - MOVQ R8, 0(DX) - MOVQ R9, 8(DX) - MOVQ R10, 16(DX) - MOVQ R11, 24(DX) - - // reduce element(CX,BX,SI,DI) using temp registers (R8,R9,R10,R11) - REDUCE(CX,BX,SI,DI,R8,R9,R10,R11) - - MOVQ a+0(FP), AX - MOVQ CX, 0(AX) - MOVQ BX, 8(AX) - MOVQ SI, 16(AX) - MOVQ DI, 24(AX) - RET - -// addVec(res, a, b *Element, n uint64) res[0...n] = a[0...n] + b[0...n] -TEXT ·addVec(SB), NOSPLIT, $0-32 - MOVQ res+0(FP), CX - MOVQ a+8(FP), AX - MOVQ b+16(FP), DX - MOVQ n+24(FP), BX - -loop_1: - TESTQ BX, BX - JEQ done_2 // n == 0, we are done - - // a[0] -> SI - // a[1] -> DI - // a[2] -> R8 - // a[3] -> R9 - MOVQ 0(AX), SI - MOVQ 8(AX), DI - MOVQ 16(AX), R8 - MOVQ 24(AX), R9 - ADDQ 0(DX), SI - ADCQ 8(DX), DI - ADCQ 16(DX), R8 - ADCQ 24(DX), R9 - - // reduce element(SI,DI,R8,R9) using temp registers (R10,R11,R12,R13) - REDUCE(SI,DI,R8,R9,R10,R11,R12,R13) - - MOVQ SI, 0(CX) - MOVQ DI, 8(CX) - MOVQ R8, 16(CX) - MOVQ R9, 24(CX) - - // increment pointers to visit next element - ADDQ $32, AX - ADDQ $32, DX - ADDQ $32, CX - DECQ BX // decrement n - JMP loop_1 - -done_2: - RET - -// subVec(res, a, b *Element, n uint64) res[0...n] = a[0...n] - b[0...n] -TEXT ·subVec(SB), NOSPLIT, $0-32 - MOVQ res+0(FP), CX - MOVQ a+8(FP), AX - MOVQ b+16(FP), DX - MOVQ n+24(FP), BX - XORQ SI, SI - -loop_3: - TESTQ BX, BX - JEQ done_4 // n == 0, we are done - - // a[0] -> DI - // a[1] -> R8 - // a[2] -> R9 - // a[3] -> R10 - MOVQ 0(AX), DI - MOVQ 8(AX), R8 - MOVQ 16(AX), R9 - MOVQ 24(AX), R10 - SUBQ 0(DX), DI - SBBQ 8(DX), R8 - SBBQ 16(DX), R9 - SBBQ 24(DX), R10 - - // reduce (a-b) mod q - // q[0] -> R11 - // q[1] -> R12 - // q[2] -> R13 - // q[3] -> R14 - MOVQ $0x43e1f593f0000001, R11 - MOVQ $0x2833e84879b97091, R12 - MOVQ $0xb85045b68181585d, R13 - MOVQ $0x30644e72e131a029, R14 - CMOVQCC SI, R11 - CMOVQCC SI, R12 - CMOVQCC SI, R13 - CMOVQCC SI, R14 - - // add registers (q or 0) to a, and set to result - ADDQ R11, DI - ADCQ R12, 
R8 - ADCQ R13, R9 - ADCQ R14, R10 - MOVQ DI, 0(CX) - MOVQ R8, 8(CX) - MOVQ R9, 16(CX) - MOVQ R10, 24(CX) - - // increment pointers to visit next element - ADDQ $32, AX - ADDQ $32, DX - ADDQ $32, CX - DECQ BX // decrement n - JMP loop_3 - -done_4: - RET - -// scalarMulVec(res, a, b *Element, n uint64) res[0...n] = a[0...n] * b -TEXT ·scalarMulVec(SB), $56-32 - CMPB ·supportAdx(SB), $1 - JNE noAdx_5 - MOVQ a+8(FP), R11 - MOVQ b+16(FP), R10 - MOVQ n+24(FP), R12 - - // scalar[0] -> SI - // scalar[1] -> DI - // scalar[2] -> R8 - // scalar[3] -> R9 - MOVQ 0(R10), SI - MOVQ 8(R10), DI - MOVQ 16(R10), R8 - MOVQ 24(R10), R9 - MOVQ res+0(FP), R10 - -loop_6: - TESTQ R12, R12 - JEQ done_7 // n == 0, we are done - - // TODO @gbotrel this is generated from the same macro as the unit mul, we should refactor this in a single asm function - // A -> BP - // t[0] -> R14 - // t[1] -> R15 - // t[2] -> CX - // t[3] -> BX - // clear the flags - XORQ AX, AX - MOVQ 0(R11), DX - - // (A,t[0]) := x[0]*y[0] + A - MULXQ SI, R14, R15 - - // (A,t[1]) := x[1]*y[0] + A - MULXQ DI, AX, CX - ADOXQ AX, R15 - - // (A,t[2]) := x[2]*y[0] + A - MULXQ R8, AX, BX - ADOXQ AX, CX - - // (A,t[3]) := x[3]*y[0] + A - MULXQ R9, AX, BP - ADOXQ AX, BX - - // A += carries from ADCXQ and ADOXQ - MOVQ $0, AX - ADOXQ AX, BP - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ R14, DX - - // clear the flags - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ q<>+0(SB), AX, R13 - ADCXQ R14, AX - MOVQ R13, R14 - - // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R15, R14 - MULXQ q<>+8(SB), AX, R15 - ADOXQ AX, R14 - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ CX, R15 - MULXQ q<>+16(SB), AX, CX - ADOXQ AX, R15 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ BX, CX - MULXQ q<>+24(SB), AX, BX - ADOXQ AX, CX - - // t[3] = C + A - MOVQ $0, AX - ADCXQ AX, BX - ADOXQ BP, BX - - // clear the flags - XORQ AX, AX - MOVQ 8(R11), DX - - // (A,t[0]) := t[0] + x[0]*y[1] + A - MULXQ SI, AX, BP - ADOXQ AX, R14 - - // (A,t[1]) := t[1] + x[1]*y[1] + A - ADCXQ BP, R15 - MULXQ DI, AX, BP - ADOXQ AX, R15 - - // (A,t[2]) := t[2] + x[2]*y[1] + A - ADCXQ BP, CX - MULXQ R8, AX, BP - ADOXQ AX, CX - - // (A,t[3]) := t[3] + x[3]*y[1] + A - ADCXQ BP, BX - MULXQ R9, AX, BP - ADOXQ AX, BX - - // A += carries from ADCXQ and ADOXQ - MOVQ $0, AX - ADCXQ AX, BP - ADOXQ AX, BP - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ R14, DX - - // clear the flags - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ q<>+0(SB), AX, R13 - ADCXQ R14, AX - MOVQ R13, R14 - - // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R15, R14 - MULXQ q<>+8(SB), AX, R15 - ADOXQ AX, R14 - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ CX, R15 - MULXQ q<>+16(SB), AX, CX - ADOXQ AX, R15 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ BX, CX - MULXQ q<>+24(SB), AX, BX - ADOXQ AX, CX - - // t[3] = C + A - MOVQ $0, AX - ADCXQ AX, BX - ADOXQ BP, BX - - // clear the flags - XORQ AX, AX - MOVQ 16(R11), DX - - // (A,t[0]) := t[0] + x[0]*y[2] + A - MULXQ SI, AX, BP - ADOXQ AX, R14 - - // (A,t[1]) := t[1] + x[1]*y[2] + A - ADCXQ BP, R15 - MULXQ DI, AX, BP - ADOXQ AX, R15 - - // (A,t[2]) := t[2] + x[2]*y[2] + A - ADCXQ BP, CX - MULXQ R8, AX, BP - ADOXQ AX, CX - - // (A,t[3]) := t[3] + x[3]*y[2] + A - ADCXQ BP, BX - MULXQ R9, AX, BP - ADOXQ AX, BX - - // A += carries from ADCXQ and ADOXQ - MOVQ $0, AX - ADCXQ AX, BP - ADOXQ AX, BP - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ R14, DX - - // clear the flags - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ q<>+0(SB), AX, R13 - ADCXQ R14, AX - MOVQ R13, R14 - - // (C,t[0]) 
:= t[1] + m*q[1] + C - ADCXQ R15, R14 - MULXQ q<>+8(SB), AX, R15 - ADOXQ AX, R14 - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ CX, R15 - MULXQ q<>+16(SB), AX, CX - ADOXQ AX, R15 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ BX, CX - MULXQ q<>+24(SB), AX, BX - ADOXQ AX, CX - - // t[3] = C + A - MOVQ $0, AX - ADCXQ AX, BX - ADOXQ BP, BX - - // clear the flags - XORQ AX, AX - MOVQ 24(R11), DX - - // (A,t[0]) := t[0] + x[0]*y[3] + A - MULXQ SI, AX, BP - ADOXQ AX, R14 - - // (A,t[1]) := t[1] + x[1]*y[3] + A - ADCXQ BP, R15 - MULXQ DI, AX, BP - ADOXQ AX, R15 - - // (A,t[2]) := t[2] + x[2]*y[3] + A - ADCXQ BP, CX - MULXQ R8, AX, BP - ADOXQ AX, CX - - // (A,t[3]) := t[3] + x[3]*y[3] + A - ADCXQ BP, BX - MULXQ R9, AX, BP - ADOXQ AX, BX - - // A += carries from ADCXQ and ADOXQ - MOVQ $0, AX - ADCXQ AX, BP - ADOXQ AX, BP - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ R14, DX - - // clear the flags - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ q<>+0(SB), AX, R13 - ADCXQ R14, AX - MOVQ R13, R14 - - // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R15, R14 - MULXQ q<>+8(SB), AX, R15 - ADOXQ AX, R14 - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ CX, R15 - MULXQ q<>+16(SB), AX, CX - ADOXQ AX, R15 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ BX, CX - MULXQ q<>+24(SB), AX, BX - ADOXQ AX, CX - - // t[3] = C + A - MOVQ $0, AX - ADCXQ AX, BX - ADOXQ BP, BX - - // reduce t mod q - // reduce element(R14,R15,CX,BX) using temp registers (R13,AX,DX,s0-8(SP)) - REDUCE(R14,R15,CX,BX,R13,AX,DX,s0-8(SP)) - - MOVQ R14, 0(R10) - MOVQ R15, 8(R10) - MOVQ CX, 16(R10) - MOVQ BX, 24(R10) - - // increment pointers to visit next element - ADDQ $32, R11 - ADDQ $32, R10 - DECQ R12 // decrement n - JMP loop_6 - -done_7: - RET - -noAdx_5: - MOVQ n+24(FP), DX - MOVQ res+0(FP), AX - MOVQ AX, (SP) - MOVQ DX, 8(SP) - MOVQ DX, 16(SP) - MOVQ a+8(FP), AX - MOVQ AX, 24(SP) - MOVQ DX, 32(SP) - MOVQ DX, 40(SP) - MOVQ b+16(FP), AX - MOVQ AX, 48(SP) - CALL ·scalarMulVecGeneric(SB) - RET - -// sumVec(res, a *Element, n uint64) res = sum(a[0...n]) -TEXT ·sumVec(SB), NOSPLIT, $0-24 - - // Derived from https://github.com/a16z/vectorized-fields - // The idea is to use Z registers to accumulate the sum of elements, 8 by 8 - // first, we handle the case where n % 8 != 0 - // then, we loop over the elements 8 by 8 and accumulate the sum in the Z registers - // finally, we reduce the sum and store it in res - // - // when we move an element of a into a Z register, we use VPMOVZXDQ - // let's note w0...w3 the 4 64bits words of ai: w0 = ai[0], w1 = ai[1], w2 = ai[2], w3 = ai[3] - // VPMOVZXDQ(ai, Z0) will result in - // Z0= [hi(w3), lo(w3), hi(w2), lo(w2), hi(w1), lo(w1), hi(w0), lo(w0)] - // with hi(wi) the high 32 bits of wi and lo(wi) the low 32 bits of wi - // we can safely add 2^32+1 times Z registers constructed this way without overflow - // since each of this lo/hi bits are moved into a "64bits" slot - // N = 2^64-1 / 2^32-1 = 2^32+1 - // - // we then propagate the carry using ADOXQ and ADCXQ - // r0 = w0l + lo(woh) - // r1 = carry + hi(woh) + w1l + lo(w1h) - // r2 = carry + hi(w1h) + w2l + lo(w2h) - // r3 = carry + hi(w2h) + w3l + lo(w3h) - // r4 = carry + hi(w3h) - // we then reduce the sum using a single-word Barrett reduction - // we pick mu = 2^288 / q; which correspond to 4.5 words max. - // meaning we must guarantee that r4 fits in 32bits. 
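
(The comment resumes below with how that 32-bit guarantee on r4 is met.) As a scalar illustration of the split-and-accumulate scheme this comment describes, a hedged Go sketch — not the deleted assembly and not the package's generic fallback — might look like the following; it splits every 64-bit limb into 32-bit halves so that on the order of 2^32 elements can be summed in plain 64-bit accumulators, then recombines them into the r0..r4 layout, leaving out the Barrett step:

package sketch

import "math/bits"

// sumWide adds up to ~2^32 four-limb elements without reducing mod q.
// Each 64-bit limb is split into its 32-bit halves and every half gets its own
// 64-bit accumulator, so no accumulator can wrap. The halves are then
// recombined into a 5-limb value r0..r4; the single-word Barrett reduction
// performed by the assembly afterwards is not shown here.
func sumWide(a [][4]uint64) (r [5]uint64) {
	var lo32, hi32 [4]uint64 // per-limb accumulators of the low/high 32-bit halves
	for _, e := range a {
		for k := 0; k < 4; k++ {
			lo32[k] += e[k] & 0xffffffff
			hi32[k] += e[k] >> 32
		}
	}

	// Recombine: limb k receives lo32[k] + (hi32[k] mod 2^32)<<32, while the
	// top 32 bits of hi32[k] spill into limb k+1. Two independent carry chains
	// mirror the ADOXQ/ADCXQ pair used by the assembly.
	var c1, c2 uint64
	r[0], c1 = bits.Add64(lo32[0], hi32[0]<<32, 0)
	r[1], c1 = bits.Add64(lo32[1], hi32[1]<<32, c1)
	r[2], c1 = bits.Add64(lo32[2], hi32[2]<<32, c1)
	r[3], c1 = bits.Add64(lo32[3], hi32[3]<<32, c1)
	r[1], c2 = bits.Add64(r[1], hi32[0]>>32, 0)
	r[2], c2 = bits.Add64(r[2], hi32[1]>>32, c2)
	r[3], c2 = bits.Add64(r[3], hi32[2]>>32, c2)
	r[4] = hi32[3]>>32 + c1 + c2
	return r
}

The AVX-512 code below does the same thing eight elements at a time in the Z registers, then finishes with the single-word Barrett reduction against mu.
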
- // To do so, we reduce N to 2^32-1 (since r4 receives 2 carries max) - - MOVQ a+8(FP), R14 - MOVQ n+16(FP), R15 - - // initialize accumulators Z0, Z1, Z2, Z3, Z4, Z5, Z6, Z7 - VXORPS Z0, Z0, Z0 - VMOVDQA64 Z0, Z1 - VMOVDQA64 Z0, Z2 - VMOVDQA64 Z0, Z3 - VMOVDQA64 Z0, Z4 - VMOVDQA64 Z0, Z5 - VMOVDQA64 Z0, Z6 - VMOVDQA64 Z0, Z7 - - // n % 8 -> CX - // n / 8 -> R15 - MOVQ R15, CX - ANDQ $7, CX - SHRQ $3, R15 - -loop_single_10: - TESTQ CX, CX - JEQ loop8by8_8 // n % 8 == 0, we are going to loop over 8 by 8 - VPMOVZXDQ 0(R14), Z8 - VPADDQ Z8, Z0, Z0 - ADDQ $32, R14 - DECQ CX // decrement nMod8 - JMP loop_single_10 - -loop8by8_8: - TESTQ R15, R15 - JEQ accumulate_11 // n == 0, we are going to accumulate - VPMOVZXDQ 0*32(R14), Z8 - VPMOVZXDQ 1*32(R14), Z9 - VPMOVZXDQ 2*32(R14), Z10 - VPMOVZXDQ 3*32(R14), Z11 - VPMOVZXDQ 4*32(R14), Z12 - VPMOVZXDQ 5*32(R14), Z13 - VPMOVZXDQ 6*32(R14), Z14 - VPMOVZXDQ 7*32(R14), Z15 - PREFETCHT0 256(R14) - VPADDQ Z8, Z0, Z0 - VPADDQ Z9, Z1, Z1 - VPADDQ Z10, Z2, Z2 - VPADDQ Z11, Z3, Z3 - VPADDQ Z12, Z4, Z4 - VPADDQ Z13, Z5, Z5 - VPADDQ Z14, Z6, Z6 - VPADDQ Z15, Z7, Z7 - - // increment pointers to visit next 8 elements - ADDQ $256, R14 - DECQ R15 // decrement n - JMP loop8by8_8 - -accumulate_11: - // accumulate the 8 Z registers into Z0 - VPADDQ Z7, Z6, Z6 - VPADDQ Z6, Z5, Z5 - VPADDQ Z5, Z4, Z4 - VPADDQ Z4, Z3, Z3 - VPADDQ Z3, Z2, Z2 - VPADDQ Z2, Z1, Z1 - VPADDQ Z1, Z0, Z0 - - // carry propagation - // lo(w0) -> BX - // hi(w0) -> SI - // lo(w1) -> DI - // hi(w1) -> R8 - // lo(w2) -> R9 - // hi(w2) -> R10 - // lo(w3) -> R11 - // hi(w3) -> R12 - VMOVQ X0, BX - VALIGNQ $1, Z0, Z0, Z0 - VMOVQ X0, SI - VALIGNQ $1, Z0, Z0, Z0 - VMOVQ X0, DI - VALIGNQ $1, Z0, Z0, Z0 - VMOVQ X0, R8 - VALIGNQ $1, Z0, Z0, Z0 - VMOVQ X0, R9 - VALIGNQ $1, Z0, Z0, Z0 - VMOVQ X0, R10 - VALIGNQ $1, Z0, Z0, Z0 - VMOVQ X0, R11 - VALIGNQ $1, Z0, Z0, Z0 - VMOVQ X0, R12 - - // lo(hi(wo)) -> R13 - // lo(hi(w1)) -> CX - // lo(hi(w2)) -> R15 - // lo(hi(w3)) -> R14 -#define SPLIT_LO_HI(lo, hi) \ - MOVQ hi, lo; \ - ANDQ $0xffffffff, lo; \ - SHLQ $32, lo; \ - SHRQ $32, hi; \ - - SPLIT_LO_HI(R13, SI) - SPLIT_LO_HI(CX, R8) - SPLIT_LO_HI(R15, R10) - SPLIT_LO_HI(R14, R12) - - // r0 = w0l + lo(woh) - // r1 = carry + hi(woh) + w1l + lo(w1h) - // r2 = carry + hi(w1h) + w2l + lo(w2h) - // r3 = carry + hi(w2h) + w3l + lo(w3h) - // r4 = carry + hi(w3h) - - XORQ AX, AX // clear the flags - ADOXQ R13, BX - ADOXQ CX, DI - ADCXQ SI, DI - ADOXQ R15, R9 - ADCXQ R8, R9 - ADOXQ R14, R11 - ADCXQ R10, R11 - ADOXQ AX, R12 - ADCXQ AX, R12 - - // r[0] -> BX - // r[1] -> DI - // r[2] -> R9 - // r[3] -> R11 - // r[4] -> R12 - // reduce using single-word Barrett - // mu=2^288 / q -> SI - MOVQ mu<>(SB), SI - MOVQ R11, AX - SHRQ $32, R12, AX - MULQ SI // high bits of res stored in DX - MULXQ q<>+0(SB), AX, SI - SUBQ AX, BX - SBBQ SI, DI - MULXQ q<>+16(SB), AX, SI - SBBQ AX, R9 - SBBQ SI, R11 - SBBQ $0, R12 - MULXQ q<>+8(SB), AX, SI - SUBQ AX, DI - SBBQ SI, R9 - MULXQ q<>+24(SB), AX, SI - SBBQ AX, R11 - SBBQ SI, R12 - MOVQ BX, R8 - MOVQ DI, R10 - MOVQ R9, R13 - MOVQ R11, CX - SUBQ q<>+0(SB), BX - SBBQ q<>+8(SB), DI - SBBQ q<>+16(SB), R9 - SBBQ q<>+24(SB), R11 - SBBQ $0, R12 - JCS modReduced_12 - MOVQ BX, R8 - MOVQ DI, R10 - MOVQ R9, R13 - MOVQ R11, CX - SUBQ q<>+0(SB), BX - SBBQ q<>+8(SB), DI - SBBQ q<>+16(SB), R9 - SBBQ q<>+24(SB), R11 - SBBQ $0, R12 - JCS modReduced_12 - MOVQ BX, R8 - MOVQ DI, R10 - MOVQ R9, R13 - MOVQ R11, CX - -modReduced_12: - MOVQ res+0(FP), SI - MOVQ R8, 0(SI) - MOVQ R10, 8(SI) - MOVQ R13, 16(SI) - MOVQ CX, 
24(SI) - -done_9: - RET diff --git a/ecc/bn254/fr/element_ops_purego.go b/ecc/bn254/fr/element_ops_purego.go index 2d1ae1780..9533ed621 100644 --- a/ecc/bn254/fr/element_ops_purego.go +++ b/ecc/bn254/fr/element_ops_purego.go @@ -89,48 +89,8 @@ func (vector *Vector) Sum() (res Element) { // x and y must be less than q func (z *Element) Mul(x, y *Element) *Element { - // Implements CIOS multiplication -- section 2.3.2 of Tolga Acar's thesis - // https://www.microsoft.com/en-us/research/wp-content/uploads/1998/06/97Acar.pdf - // - // The algorithm: - // - // for i=0 to N-1 - // C := 0 - // for j=0 to N-1 - // (C,t[j]) := t[j] + x[j]*y[i] + C - // (t[N+1],t[N]) := t[N] + C - // - // C := 0 - // m := t[0]*q'[0] mod D - // (C,_) := t[0] + m*q[0] - // for j=1 to N-1 - // (C,t[j-1]) := t[j] + m*q[j] + C - // - // (C,t[N-1]) := t[N] + C - // t[N] := t[N+1] + C - // - // → N is the number of machine words needed to store the modulus q - // → D is the word size. For example, on a 64-bit architecture D is 2 64 - // → x[i], y[i], q[i] is the ith word of the numbers x,y,q - // → q'[0] is the lowest word of the number -q⁻¹ mod r. This quantity is pre-computed, as it does not depend on the inputs. - // → t is a temporary array of size N+2 - // → C, S are machine words. A pair (C,S) refers to (hi-bits, lo-bits) of a two-word number - // - // As described here https://hackmd.io/@gnark/modular_multiplication we can get rid of one carry chain and simplify: - // (also described in https://eprint.iacr.org/2022/1400.pdf annex) - // - // for i=0 to N-1 - // (A,t[0]) := t[0] + x[0]*y[i] - // m := t[0]*q'[0] mod W - // C,_ := t[0] + m*q[0] - // for j=1 to N-1 - // (A,t[j]) := t[j] + x[j]*y[i] + A - // (C,t[j-1]) := t[j] + m*q[j] + C - // - // t[N-1] = C + A - // - // This optimization saves 5N + 2 additions in the algorithm, and can be used whenever the highest bit - // of the modulus is zero (and not all of the remaining bits are set). + // Algorithm 2 of "Faster Montgomery Multiplication and Multi-Scalar-Multiplication for SNARKS" + // by Y. El Housni and G. Botrel https://doi.org/10.46586/tches.v2023.i3.504-521 var t0, t1, t2, t3 uint64 var u0, u1, u2, u3 uint64 diff --git a/ecc/bn254/fr/element_test.go b/ecc/bn254/fr/element_test.go index 2fb459444..4faf2a090 100644 --- a/ecc/bn254/fr/element_test.go +++ b/ecc/bn254/fr/element_test.go @@ -637,7 +637,6 @@ func TestElementBitLen(t *testing.T) { )) properties.TestingRun(t, gopter.ConsoleReporter(false)) - } func TestElementButterflies(t *testing.T) { @@ -719,11 +718,17 @@ func TestElementVecOps(t *testing.T) { // set m to max values element // it's not really q-1 (since we have montgomery representation) // but it's the "largest" legal value - qMinus1 := new(big.Int).Sub(Modulus(), big.NewInt(1)) - - var eQMinus1 Element - for i, v := range qMinus1.Bits() { - eQMinus1[i] = uint64(v) + eQMinus1 := qElement + if eQMinus1[0] != 0 { + eQMinus1[0]-- + } else { + eQMinus1[0] = ^uint64(0) + for i := 1; i < len(eQMinus1); i++ { + if eQMinus1[i] != 0 { + eQMinus1[i]-- + break + } + } } for i := 0; i < N; i++ { diff --git a/ecc/bw6-633/fp/element.go b/ecc/bw6-633/fp/element.go index 475abd7e5..7656002f4 100644 --- a/ecc/bw6-633/fp/element.go +++ b/ecc/bw6-633/fp/element.go @@ -609,32 +609,8 @@ func (z *Element) Select(c int, x0 *Element, x1 *Element) *Element { // and is used for testing purposes. 
func _mulGeneric(z, x, y *Element) { - // Implements CIOS multiplication -- section 2.3.2 of Tolga Acar's thesis - // https://www.microsoft.com/en-us/research/wp-content/uploads/1998/06/97Acar.pdf - // - // The algorithm: - // - // for i=0 to N-1 - // C := 0 - // for j=0 to N-1 - // (C,t[j]) := t[j] + x[j]*y[i] + C - // (t[N+1],t[N]) := t[N] + C - // - // C := 0 - // m := t[0]*q'[0] mod D - // (C,_) := t[0] + m*q[0] - // for j=1 to N-1 - // (C,t[j-1]) := t[j] + m*q[j] + C - // - // (C,t[N-1]) := t[N] + C - // t[N] := t[N+1] + C - // - // → N is the number of machine words needed to store the modulus q - // → D is the word size. For example, on a 64-bit architecture D is 2 64 - // → x[i], y[i], q[i] is the ith word of the numbers x,y,q - // → q'[0] is the lowest word of the number -q⁻¹ mod r. This quantity is pre-computed, as it does not depend on the inputs. - // → t is a temporary array of size N+2 - // → C, S are machine words. A pair (C,S) refers to (hi-bits, lo-bits) of a two-word number + // Algorithm 2 of "Faster Montgomery Multiplication and Multi-Scalar-Multiplication for SNARKS" + // by Y. El Housni and G. Botrel https://doi.org/10.46586/tches.v2023.i3.504-521 var t [11]uint64 var D uint64 diff --git a/ecc/bw6-633/fp/element_ops_amd64.go b/ecc/bw6-633/fp/element_ops_amd64.go index 83bba45ae..ed2803d71 100644 --- a/ecc/bw6-633/fp/element_ops_amd64.go +++ b/ecc/bw6-633/fp/element_ops_amd64.go @@ -50,48 +50,8 @@ func Butterfly(a, b *Element) // x and y must be less than q func (z *Element) Mul(x, y *Element) *Element { - // Implements CIOS multiplication -- section 2.3.2 of Tolga Acar's thesis - // https://www.microsoft.com/en-us/research/wp-content/uploads/1998/06/97Acar.pdf - // - // The algorithm: - // - // for i=0 to N-1 - // C := 0 - // for j=0 to N-1 - // (C,t[j]) := t[j] + x[j]*y[i] + C - // (t[N+1],t[N]) := t[N] + C - // - // C := 0 - // m := t[0]*q'[0] mod D - // (C,_) := t[0] + m*q[0] - // for j=1 to N-1 - // (C,t[j-1]) := t[j] + m*q[j] + C - // - // (C,t[N-1]) := t[N] + C - // t[N] := t[N+1] + C - // - // → N is the number of machine words needed to store the modulus q - // → D is the word size. For example, on a 64-bit architecture D is 2 64 - // → x[i], y[i], q[i] is the ith word of the numbers x,y,q - // → q'[0] is the lowest word of the number -q⁻¹ mod r. This quantity is pre-computed, as it does not depend on the inputs. - // → t is a temporary array of size N+2 - // → C, S are machine words. A pair (C,S) refers to (hi-bits, lo-bits) of a two-word number - // - // As described here https://hackmd.io/@gnark/modular_multiplication we can get rid of one carry chain and simplify: - // (also described in https://eprint.iacr.org/2022/1400.pdf annex) - // - // for i=0 to N-1 - // (A,t[0]) := t[0] + x[0]*y[i] - // m := t[0]*q'[0] mod W - // C,_ := t[0] + m*q[0] - // for j=1 to N-1 - // (A,t[j]) := t[j] + x[j]*y[i] + A - // (C,t[j-1]) := t[j] + m*q[j] + C - // - // t[N-1] = C + A - // - // This optimization saves 5N + 2 additions in the algorithm, and can be used whenever the highest bit - // of the modulus is zero (and not all of the remaining bits are set). + // Algorithm 2 of "Faster Montgomery Multiplication and Multi-Scalar-Multiplication for SNARKS" + // by Y. El Housni and G. 
Botrel https://doi.org/10.46586/tches.v2023.i3.504-521 mul(z, x, y) return z diff --git a/ecc/bw6-633/fp/element_ops_amd64.s b/ecc/bw6-633/fp/element_ops_amd64.s index 12a078963..d14f468e3 100644 --- a/ecc/bw6-633/fp/element_ops_amd64.s +++ b/ecc/bw6-633/fp/element_ops_amd64.s @@ -1,21 +1,17 @@ // +build !purego -// Copyright 2020 ConsenSys Software Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "textflag.h" -#include "funcdata.h" +#define q0 $0xd74916ea4570000d +#define q1 $0x3d369bd31147f73c +#define q2 $0xd7b5ce7ab839c225 +#define q3 $0x7e0e8850edbda407 +#define q4 $0xb8da9f5e83f57c49 +#define q5 $0x8152a6c0fadea490 +#define q6 $0x4e59769ad9bbda2f +#define q7 $0xa8fcd8c75d79d2c7 +#define q8 $0xfc1a174f01d72ab5 +#define q9 $0x0126633cc0f35f63 + +#include "../../../field/asm/element_10w_amd64.h" // modulus q DATA q<>+0(SB)/8, $0xd74916ea4570000d @@ -34,403 +30,3 @@ GLOBL q<>(SB), (RODATA+NOPTR), $80 DATA qInv0<>(SB)/8, $0xb50f29ab0b03b13b GLOBL qInv0<>(SB), (RODATA+NOPTR), $8 -#define REDUCE(ra0, ra1, ra2, ra3, ra4, ra5, ra6, ra7, ra8, ra9, rb0, rb1, rb2, rb3, rb4, rb5, rb6, rb7, rb8, rb9) \ - MOVQ ra0, rb0; \ - SUBQ q<>(SB), ra0; \ - MOVQ ra1, rb1; \ - SBBQ q<>+8(SB), ra1; \ - MOVQ ra2, rb2; \ - SBBQ q<>+16(SB), ra2; \ - MOVQ ra3, rb3; \ - SBBQ q<>+24(SB), ra3; \ - MOVQ ra4, rb4; \ - SBBQ q<>+32(SB), ra4; \ - MOVQ ra5, rb5; \ - SBBQ q<>+40(SB), ra5; \ - MOVQ ra6, rb6; \ - SBBQ q<>+48(SB), ra6; \ - MOVQ ra7, rb7; \ - SBBQ q<>+56(SB), ra7; \ - MOVQ ra8, rb8; \ - SBBQ q<>+64(SB), ra8; \ - MOVQ ra9, rb9; \ - SBBQ q<>+72(SB), ra9; \ - CMOVQCS rb0, ra0; \ - CMOVQCS rb1, ra1; \ - CMOVQCS rb2, ra2; \ - CMOVQCS rb3, ra3; \ - CMOVQCS rb4, ra4; \ - CMOVQCS rb5, ra5; \ - CMOVQCS rb6, ra6; \ - CMOVQCS rb7, ra7; \ - CMOVQCS rb8, ra8; \ - CMOVQCS rb9, ra9; \ - -TEXT ·reduce(SB), $56-8 - MOVQ res+0(FP), AX - MOVQ 0(AX), DX - MOVQ 8(AX), CX - MOVQ 16(AX), BX - MOVQ 24(AX), SI - MOVQ 32(AX), DI - MOVQ 40(AX), R8 - MOVQ 48(AX), R9 - MOVQ 56(AX), R10 - MOVQ 64(AX), R11 - MOVQ 72(AX), R12 - - // reduce element(DX,CX,BX,SI,DI,R8,R9,R10,R11,R12) using temp registers (R13,R14,R15,s0-8(SP),s1-16(SP),s2-24(SP),s3-32(SP),s4-40(SP),s5-48(SP),s6-56(SP)) - REDUCE(DX,CX,BX,SI,DI,R8,R9,R10,R11,R12,R13,R14,R15,s0-8(SP),s1-16(SP),s2-24(SP),s3-32(SP),s4-40(SP),s5-48(SP),s6-56(SP)) - - MOVQ DX, 0(AX) - MOVQ CX, 8(AX) - MOVQ BX, 16(AX) - MOVQ SI, 24(AX) - MOVQ DI, 32(AX) - MOVQ R8, 40(AX) - MOVQ R9, 48(AX) - MOVQ R10, 56(AX) - MOVQ R11, 64(AX) - MOVQ R12, 72(AX) - RET - -// MulBy3(x *Element) -TEXT ·MulBy3(SB), $56-8 - MOVQ x+0(FP), AX - MOVQ 0(AX), DX - MOVQ 8(AX), CX - MOVQ 16(AX), BX - MOVQ 24(AX), SI - MOVQ 32(AX), DI - MOVQ 40(AX), R8 - MOVQ 48(AX), R9 - MOVQ 56(AX), R10 - MOVQ 64(AX), R11 - MOVQ 72(AX), R12 - ADDQ DX, DX - ADCQ CX, CX - ADCQ BX, BX - ADCQ SI, SI - ADCQ DI, DI - ADCQ R8, R8 - ADCQ R9, R9 - ADCQ R10, R10 - ADCQ R11, R11 - ADCQ R12, R12 - - // reduce element(DX,CX,BX,SI,DI,R8,R9,R10,R11,R12) using temp registers 
(R13,R14,R15,s0-8(SP),s1-16(SP),s2-24(SP),s3-32(SP),s4-40(SP),s5-48(SP),s6-56(SP)) - REDUCE(DX,CX,BX,SI,DI,R8,R9,R10,R11,R12,R13,R14,R15,s0-8(SP),s1-16(SP),s2-24(SP),s3-32(SP),s4-40(SP),s5-48(SP),s6-56(SP)) - - ADDQ 0(AX), DX - ADCQ 8(AX), CX - ADCQ 16(AX), BX - ADCQ 24(AX), SI - ADCQ 32(AX), DI - ADCQ 40(AX), R8 - ADCQ 48(AX), R9 - ADCQ 56(AX), R10 - ADCQ 64(AX), R11 - ADCQ 72(AX), R12 - - // reduce element(DX,CX,BX,SI,DI,R8,R9,R10,R11,R12) using temp registers (R13,R14,R15,s0-8(SP),s1-16(SP),s2-24(SP),s3-32(SP),s4-40(SP),s5-48(SP),s6-56(SP)) - REDUCE(DX,CX,BX,SI,DI,R8,R9,R10,R11,R12,R13,R14,R15,s0-8(SP),s1-16(SP),s2-24(SP),s3-32(SP),s4-40(SP),s5-48(SP),s6-56(SP)) - - MOVQ DX, 0(AX) - MOVQ CX, 8(AX) - MOVQ BX, 16(AX) - MOVQ SI, 24(AX) - MOVQ DI, 32(AX) - MOVQ R8, 40(AX) - MOVQ R9, 48(AX) - MOVQ R10, 56(AX) - MOVQ R11, 64(AX) - MOVQ R12, 72(AX) - RET - -// MulBy5(x *Element) -TEXT ·MulBy5(SB), $56-8 - MOVQ x+0(FP), AX - MOVQ 0(AX), DX - MOVQ 8(AX), CX - MOVQ 16(AX), BX - MOVQ 24(AX), SI - MOVQ 32(AX), DI - MOVQ 40(AX), R8 - MOVQ 48(AX), R9 - MOVQ 56(AX), R10 - MOVQ 64(AX), R11 - MOVQ 72(AX), R12 - ADDQ DX, DX - ADCQ CX, CX - ADCQ BX, BX - ADCQ SI, SI - ADCQ DI, DI - ADCQ R8, R8 - ADCQ R9, R9 - ADCQ R10, R10 - ADCQ R11, R11 - ADCQ R12, R12 - - // reduce element(DX,CX,BX,SI,DI,R8,R9,R10,R11,R12) using temp registers (R13,R14,R15,s0-8(SP),s1-16(SP),s2-24(SP),s3-32(SP),s4-40(SP),s5-48(SP),s6-56(SP)) - REDUCE(DX,CX,BX,SI,DI,R8,R9,R10,R11,R12,R13,R14,R15,s0-8(SP),s1-16(SP),s2-24(SP),s3-32(SP),s4-40(SP),s5-48(SP),s6-56(SP)) - - ADDQ DX, DX - ADCQ CX, CX - ADCQ BX, BX - ADCQ SI, SI - ADCQ DI, DI - ADCQ R8, R8 - ADCQ R9, R9 - ADCQ R10, R10 - ADCQ R11, R11 - ADCQ R12, R12 - - // reduce element(DX,CX,BX,SI,DI,R8,R9,R10,R11,R12) using temp registers (R13,R14,R15,s0-8(SP),s1-16(SP),s2-24(SP),s3-32(SP),s4-40(SP),s5-48(SP),s6-56(SP)) - REDUCE(DX,CX,BX,SI,DI,R8,R9,R10,R11,R12,R13,R14,R15,s0-8(SP),s1-16(SP),s2-24(SP),s3-32(SP),s4-40(SP),s5-48(SP),s6-56(SP)) - - ADDQ 0(AX), DX - ADCQ 8(AX), CX - ADCQ 16(AX), BX - ADCQ 24(AX), SI - ADCQ 32(AX), DI - ADCQ 40(AX), R8 - ADCQ 48(AX), R9 - ADCQ 56(AX), R10 - ADCQ 64(AX), R11 - ADCQ 72(AX), R12 - - // reduce element(DX,CX,BX,SI,DI,R8,R9,R10,R11,R12) using temp registers (R13,R14,R15,s0-8(SP),s1-16(SP),s2-24(SP),s3-32(SP),s4-40(SP),s5-48(SP),s6-56(SP)) - REDUCE(DX,CX,BX,SI,DI,R8,R9,R10,R11,R12,R13,R14,R15,s0-8(SP),s1-16(SP),s2-24(SP),s3-32(SP),s4-40(SP),s5-48(SP),s6-56(SP)) - - MOVQ DX, 0(AX) - MOVQ CX, 8(AX) - MOVQ BX, 16(AX) - MOVQ SI, 24(AX) - MOVQ DI, 32(AX) - MOVQ R8, 40(AX) - MOVQ R9, 48(AX) - MOVQ R10, 56(AX) - MOVQ R11, 64(AX) - MOVQ R12, 72(AX) - RET - -// MulBy13(x *Element) -TEXT ·MulBy13(SB), $136-8 - MOVQ x+0(FP), AX - MOVQ 0(AX), DX - MOVQ 8(AX), CX - MOVQ 16(AX), BX - MOVQ 24(AX), SI - MOVQ 32(AX), DI - MOVQ 40(AX), R8 - MOVQ 48(AX), R9 - MOVQ 56(AX), R10 - MOVQ 64(AX), R11 - MOVQ 72(AX), R12 - ADDQ DX, DX - ADCQ CX, CX - ADCQ BX, BX - ADCQ SI, SI - ADCQ DI, DI - ADCQ R8, R8 - ADCQ R9, R9 - ADCQ R10, R10 - ADCQ R11, R11 - ADCQ R12, R12 - - // reduce element(DX,CX,BX,SI,DI,R8,R9,R10,R11,R12) using temp registers (R13,R14,R15,s0-8(SP),s1-16(SP),s2-24(SP),s3-32(SP),s4-40(SP),s5-48(SP),s6-56(SP)) - REDUCE(DX,CX,BX,SI,DI,R8,R9,R10,R11,R12,R13,R14,R15,s0-8(SP),s1-16(SP),s2-24(SP),s3-32(SP),s4-40(SP),s5-48(SP),s6-56(SP)) - - ADDQ DX, DX - ADCQ CX, CX - ADCQ BX, BX - ADCQ SI, SI - ADCQ DI, DI - ADCQ R8, R8 - ADCQ R9, R9 - ADCQ R10, R10 - ADCQ R11, R11 - ADCQ R12, R12 - - // reduce element(DX,CX,BX,SI,DI,R8,R9,R10,R11,R12) using temp registers 
(s7-64(SP),s8-72(SP),s9-80(SP),s10-88(SP),s11-96(SP),s12-104(SP),s13-112(SP),s14-120(SP),s15-128(SP),s16-136(SP)) - REDUCE(DX,CX,BX,SI,DI,R8,R9,R10,R11,R12,s7-64(SP),s8-72(SP),s9-80(SP),s10-88(SP),s11-96(SP),s12-104(SP),s13-112(SP),s14-120(SP),s15-128(SP),s16-136(SP)) - - MOVQ DX, s7-64(SP) - MOVQ CX, s8-72(SP) - MOVQ BX, s9-80(SP) - MOVQ SI, s10-88(SP) - MOVQ DI, s11-96(SP) - MOVQ R8, s12-104(SP) - MOVQ R9, s13-112(SP) - MOVQ R10, s14-120(SP) - MOVQ R11, s15-128(SP) - MOVQ R12, s16-136(SP) - ADDQ DX, DX - ADCQ CX, CX - ADCQ BX, BX - ADCQ SI, SI - ADCQ DI, DI - ADCQ R8, R8 - ADCQ R9, R9 - ADCQ R10, R10 - ADCQ R11, R11 - ADCQ R12, R12 - - // reduce element(DX,CX,BX,SI,DI,R8,R9,R10,R11,R12) using temp registers (R13,R14,R15,s0-8(SP),s1-16(SP),s2-24(SP),s3-32(SP),s4-40(SP),s5-48(SP),s6-56(SP)) - REDUCE(DX,CX,BX,SI,DI,R8,R9,R10,R11,R12,R13,R14,R15,s0-8(SP),s1-16(SP),s2-24(SP),s3-32(SP),s4-40(SP),s5-48(SP),s6-56(SP)) - - ADDQ s7-64(SP), DX - ADCQ s8-72(SP), CX - ADCQ s9-80(SP), BX - ADCQ s10-88(SP), SI - ADCQ s11-96(SP), DI - ADCQ s12-104(SP), R8 - ADCQ s13-112(SP), R9 - ADCQ s14-120(SP), R10 - ADCQ s15-128(SP), R11 - ADCQ s16-136(SP), R12 - - // reduce element(DX,CX,BX,SI,DI,R8,R9,R10,R11,R12) using temp registers (R13,R14,R15,s0-8(SP),s1-16(SP),s2-24(SP),s3-32(SP),s4-40(SP),s5-48(SP),s6-56(SP)) - REDUCE(DX,CX,BX,SI,DI,R8,R9,R10,R11,R12,R13,R14,R15,s0-8(SP),s1-16(SP),s2-24(SP),s3-32(SP),s4-40(SP),s5-48(SP),s6-56(SP)) - - ADDQ 0(AX), DX - ADCQ 8(AX), CX - ADCQ 16(AX), BX - ADCQ 24(AX), SI - ADCQ 32(AX), DI - ADCQ 40(AX), R8 - ADCQ 48(AX), R9 - ADCQ 56(AX), R10 - ADCQ 64(AX), R11 - ADCQ 72(AX), R12 - - // reduce element(DX,CX,BX,SI,DI,R8,R9,R10,R11,R12) using temp registers (R13,R14,R15,s0-8(SP),s1-16(SP),s2-24(SP),s3-32(SP),s4-40(SP),s5-48(SP),s6-56(SP)) - REDUCE(DX,CX,BX,SI,DI,R8,R9,R10,R11,R12,R13,R14,R15,s0-8(SP),s1-16(SP),s2-24(SP),s3-32(SP),s4-40(SP),s5-48(SP),s6-56(SP)) - - MOVQ DX, 0(AX) - MOVQ CX, 8(AX) - MOVQ BX, 16(AX) - MOVQ SI, 24(AX) - MOVQ DI, 32(AX) - MOVQ R8, 40(AX) - MOVQ R9, 48(AX) - MOVQ R10, 56(AX) - MOVQ R11, 64(AX) - MOVQ R12, 72(AX) - RET - -// Butterfly(a, b *Element) sets a = a + b; b = a - b -TEXT ·Butterfly(SB), $56-16 - MOVQ b+8(FP), AX - MOVQ 0(AX), DX - MOVQ 8(AX), CX - MOVQ 16(AX), BX - MOVQ 24(AX), SI - MOVQ 32(AX), DI - MOVQ 40(AX), R8 - MOVQ 48(AX), R9 - MOVQ 56(AX), R10 - MOVQ 64(AX), R11 - MOVQ 72(AX), R12 - MOVQ a+0(FP), AX - ADDQ 0(AX), DX - ADCQ 8(AX), CX - ADCQ 16(AX), BX - ADCQ 24(AX), SI - ADCQ 32(AX), DI - ADCQ 40(AX), R8 - ADCQ 48(AX), R9 - ADCQ 56(AX), R10 - ADCQ 64(AX), R11 - ADCQ 72(AX), R12 - MOVQ DX, R13 - MOVQ CX, R14 - MOVQ BX, R15 - MOVQ SI, s0-8(SP) - MOVQ DI, s1-16(SP) - MOVQ R8, s2-24(SP) - MOVQ R9, s3-32(SP) - MOVQ R10, s4-40(SP) - MOVQ R11, s5-48(SP) - MOVQ R12, s6-56(SP) - MOVQ 0(AX), DX - MOVQ 8(AX), CX - MOVQ 16(AX), BX - MOVQ 24(AX), SI - MOVQ 32(AX), DI - MOVQ 40(AX), R8 - MOVQ 48(AX), R9 - MOVQ 56(AX), R10 - MOVQ 64(AX), R11 - MOVQ 72(AX), R12 - MOVQ b+8(FP), AX - SUBQ 0(AX), DX - SBBQ 8(AX), CX - SBBQ 16(AX), BX - SBBQ 24(AX), SI - SBBQ 32(AX), DI - SBBQ 40(AX), R8 - SBBQ 48(AX), R9 - SBBQ 56(AX), R10 - SBBQ 64(AX), R11 - SBBQ 72(AX), R12 - JCC noReduce_1 - MOVQ $0xd74916ea4570000d, AX - ADDQ AX, DX - MOVQ $0x3d369bd31147f73c, AX - ADCQ AX, CX - MOVQ $0xd7b5ce7ab839c225, AX - ADCQ AX, BX - MOVQ $0x7e0e8850edbda407, AX - ADCQ AX, SI - MOVQ $0xb8da9f5e83f57c49, AX - ADCQ AX, DI - MOVQ $0x8152a6c0fadea490, AX - ADCQ AX, R8 - MOVQ $0x4e59769ad9bbda2f, AX - ADCQ AX, R9 - MOVQ $0xa8fcd8c75d79d2c7, AX - ADCQ AX, R10 - MOVQ 
$0xfc1a174f01d72ab5, AX - ADCQ AX, R11 - MOVQ $0x0126633cc0f35f63, AX - ADCQ AX, R12 - -noReduce_1: - MOVQ b+8(FP), AX - MOVQ DX, 0(AX) - MOVQ CX, 8(AX) - MOVQ BX, 16(AX) - MOVQ SI, 24(AX) - MOVQ DI, 32(AX) - MOVQ R8, 40(AX) - MOVQ R9, 48(AX) - MOVQ R10, 56(AX) - MOVQ R11, 64(AX) - MOVQ R12, 72(AX) - MOVQ R13, DX - MOVQ R14, CX - MOVQ R15, BX - MOVQ s0-8(SP), SI - MOVQ s1-16(SP), DI - MOVQ s2-24(SP), R8 - MOVQ s3-32(SP), R9 - MOVQ s4-40(SP), R10 - MOVQ s5-48(SP), R11 - MOVQ s6-56(SP), R12 - - // reduce element(DX,CX,BX,SI,DI,R8,R9,R10,R11,R12) using temp registers (R13,R14,R15,s0-8(SP),s1-16(SP),s2-24(SP),s3-32(SP),s4-40(SP),s5-48(SP),s6-56(SP)) - REDUCE(DX,CX,BX,SI,DI,R8,R9,R10,R11,R12,R13,R14,R15,s0-8(SP),s1-16(SP),s2-24(SP),s3-32(SP),s4-40(SP),s5-48(SP),s6-56(SP)) - - MOVQ a+0(FP), AX - MOVQ DX, 0(AX) - MOVQ CX, 8(AX) - MOVQ BX, 16(AX) - MOVQ SI, 24(AX) - MOVQ DI, 32(AX) - MOVQ R8, 40(AX) - MOVQ R9, 48(AX) - MOVQ R10, 56(AX) - MOVQ R11, 64(AX) - MOVQ R12, 72(AX) - RET diff --git a/ecc/bw6-633/fp/element_ops_purego.go b/ecc/bw6-633/fp/element_ops_purego.go index 69c68919e..3b5d489a3 100644 --- a/ecc/bw6-633/fp/element_ops_purego.go +++ b/ecc/bw6-633/fp/element_ops_purego.go @@ -71,48 +71,8 @@ func reduce(z *Element) { // x and y must be less than q func (z *Element) Mul(x, y *Element) *Element { - // Implements CIOS multiplication -- section 2.3.2 of Tolga Acar's thesis - // https://www.microsoft.com/en-us/research/wp-content/uploads/1998/06/97Acar.pdf - // - // The algorithm: - // - // for i=0 to N-1 - // C := 0 - // for j=0 to N-1 - // (C,t[j]) := t[j] + x[j]*y[i] + C - // (t[N+1],t[N]) := t[N] + C - // - // C := 0 - // m := t[0]*q'[0] mod D - // (C,_) := t[0] + m*q[0] - // for j=1 to N-1 - // (C,t[j-1]) := t[j] + m*q[j] + C - // - // (C,t[N-1]) := t[N] + C - // t[N] := t[N+1] + C - // - // → N is the number of machine words needed to store the modulus q - // → D is the word size. For example, on a 64-bit architecture D is 2 64 - // → x[i], y[i], q[i] is the ith word of the numbers x,y,q - // → q'[0] is the lowest word of the number -q⁻¹ mod r. This quantity is pre-computed, as it does not depend on the inputs. - // → t is a temporary array of size N+2 - // → C, S are machine words. A pair (C,S) refers to (hi-bits, lo-bits) of a two-word number - // - // As described here https://hackmd.io/@gnark/modular_multiplication we can get rid of one carry chain and simplify: - // (also described in https://eprint.iacr.org/2022/1400.pdf annex) - // - // for i=0 to N-1 - // (A,t[0]) := t[0] + x[0]*y[i] - // m := t[0]*q'[0] mod W - // C,_ := t[0] + m*q[0] - // for j=1 to N-1 - // (A,t[j]) := t[j] + x[j]*y[i] + A - // (C,t[j-1]) := t[j] + m*q[j] + C - // - // t[N-1] = C + A - // - // This optimization saves 5N + 2 additions in the algorithm, and can be used whenever the highest bit - // of the modulus is zero (and not all of the remaining bits are set). + // Algorithm 2 of "Faster Montgomery Multiplication and Multi-Scalar-Multiplication for SNARKS" + // by Y. El Housni and G. 
Botrel https://doi.org/10.46586/tches.v2023.i3.504-521 var t0, t1, t2, t3, t4, t5, t6, t7, t8, t9 uint64 var u0, u1, u2, u3, u4, u5, u6, u7, u8, u9 uint64 diff --git a/ecc/bw6-633/fp/element_test.go b/ecc/bw6-633/fp/element_test.go index 35b897fa2..16d6f6b54 100644 --- a/ecc/bw6-633/fp/element_test.go +++ b/ecc/bw6-633/fp/element_test.go @@ -649,7 +649,6 @@ func TestElementBitLen(t *testing.T) { )) properties.TestingRun(t, gopter.ConsoleReporter(false)) - } func TestElementButterflies(t *testing.T) { @@ -731,11 +730,17 @@ func TestElementVecOps(t *testing.T) { // set m to max values element // it's not really q-1 (since we have montgomery representation) // but it's the "largest" legal value - qMinus1 := new(big.Int).Sub(Modulus(), big.NewInt(1)) - - var eQMinus1 Element - for i, v := range qMinus1.Bits() { - eQMinus1[i] = uint64(v) + eQMinus1 := qElement + if eQMinus1[0] != 0 { + eQMinus1[0]-- + } else { + eQMinus1[0] = ^uint64(0) + for i := 1; i < len(eQMinus1); i++ { + if eQMinus1[i] != 0 { + eQMinus1[i]-- + break + } + } } for i := 0; i < N; i++ { diff --git a/ecc/bw6-633/fr/element.go b/ecc/bw6-633/fr/element.go index 208f672b1..8841cd342 100644 --- a/ecc/bw6-633/fr/element.go +++ b/ecc/bw6-633/fr/element.go @@ -499,32 +499,8 @@ func (z *Element) Select(c int, x0 *Element, x1 *Element) *Element { // and is used for testing purposes. func _mulGeneric(z, x, y *Element) { - // Implements CIOS multiplication -- section 2.3.2 of Tolga Acar's thesis - // https://www.microsoft.com/en-us/research/wp-content/uploads/1998/06/97Acar.pdf - // - // The algorithm: - // - // for i=0 to N-1 - // C := 0 - // for j=0 to N-1 - // (C,t[j]) := t[j] + x[j]*y[i] + C - // (t[N+1],t[N]) := t[N] + C - // - // C := 0 - // m := t[0]*q'[0] mod D - // (C,_) := t[0] + m*q[0] - // for j=1 to N-1 - // (C,t[j-1]) := t[j] + m*q[j] + C - // - // (C,t[N-1]) := t[N] + C - // t[N] := t[N+1] + C - // - // → N is the number of machine words needed to store the modulus q - // → D is the word size. For example, on a 64-bit architecture D is 2 64 - // → x[i], y[i], q[i] is the ith word of the numbers x,y,q - // → q'[0] is the lowest word of the number -q⁻¹ mod r. This quantity is pre-computed, as it does not depend on the inputs. - // → t is a temporary array of size N+2 - // → C, S are machine words. A pair (C,S) refers to (hi-bits, lo-bits) of a two-word number + // Algorithm 2 of "Faster Montgomery Multiplication and Multi-Scalar-Multiplication for SNARKS" + // by Y. El Housni and G. Botrel https://doi.org/10.46586/tches.v2023.i3.504-521 var t [6]uint64 var D uint64 diff --git a/ecc/bw6-633/fr/element_ops_amd64.go b/ecc/bw6-633/fr/element_ops_amd64.go index e40a9caed..83d40c28c 100644 --- a/ecc/bw6-633/fr/element_ops_amd64.go +++ b/ecc/bw6-633/fr/element_ops_amd64.go @@ -50,48 +50,8 @@ func Butterfly(a, b *Element) // x and y must be less than q func (z *Element) Mul(x, y *Element) *Element { - // Implements CIOS multiplication -- section 2.3.2 of Tolga Acar's thesis - // https://www.microsoft.com/en-us/research/wp-content/uploads/1998/06/97Acar.pdf - // - // The algorithm: - // - // for i=0 to N-1 - // C := 0 - // for j=0 to N-1 - // (C,t[j]) := t[j] + x[j]*y[i] + C - // (t[N+1],t[N]) := t[N] + C - // - // C := 0 - // m := t[0]*q'[0] mod D - // (C,_) := t[0] + m*q[0] - // for j=1 to N-1 - // (C,t[j-1]) := t[j] + m*q[j] + C - // - // (C,t[N-1]) := t[N] + C - // t[N] := t[N+1] + C - // - // → N is the number of machine words needed to store the modulus q - // → D is the word size. 
For example, on a 64-bit architecture D is 2 64 - // → x[i], y[i], q[i] is the ith word of the numbers x,y,q - // → q'[0] is the lowest word of the number -q⁻¹ mod r. This quantity is pre-computed, as it does not depend on the inputs. - // → t is a temporary array of size N+2 - // → C, S are machine words. A pair (C,S) refers to (hi-bits, lo-bits) of a two-word number - // - // As described here https://hackmd.io/@gnark/modular_multiplication we can get rid of one carry chain and simplify: - // (also described in https://eprint.iacr.org/2022/1400.pdf annex) - // - // for i=0 to N-1 - // (A,t[0]) := t[0] + x[0]*y[i] - // m := t[0]*q'[0] mod W - // C,_ := t[0] + m*q[0] - // for j=1 to N-1 - // (A,t[j]) := t[j] + x[j]*y[i] + A - // (C,t[j-1]) := t[j] + m*q[j] + C - // - // t[N-1] = C + A - // - // This optimization saves 5N + 2 additions in the algorithm, and can be used whenever the highest bit - // of the modulus is zero (and not all of the remaining bits are set). + // Algorithm 2 of "Faster Montgomery Multiplication and Multi-Scalar-Multiplication for SNARKS" + // by Y. El Housni and G. Botrel https://doi.org/10.46586/tches.v2023.i3.504-521 mul(z, x, y) return z diff --git a/ecc/bw6-633/fr/element_ops_amd64.s b/ecc/bw6-633/fr/element_ops_amd64.s index 9528ab595..87385b8d8 100644 --- a/ecc/bw6-633/fr/element_ops_amd64.s +++ b/ecc/bw6-633/fr/element_ops_amd64.s @@ -1,21 +1,12 @@ // +build !purego -// Copyright 2020 ConsenSys Software Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
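
As elsewhere in this patch, the .s file below shrinks to a handful of modulus-limb #defines plus an #include of the shared field/asm/element_5w_amd64.h macro header. On the Go side the assembly routines remain visible only as thin stub declarations; a rough sketch of that pattern (build tag and //go:noescape directives are assumptions here, signatures taken from the hunks above, Element being the package's limb array type) is:

//go:build !purego

package fr

// Stub declarations for the routines implemented in element_ops_amd64.s;
// a sketch of the pattern only, not the full generated file.

//go:noescape
func mul(res, x, y *Element)

//go:noescape
func addVec(res, a, b *Element, n uint64)

//go:noescape
func subVec(res, a, b *Element, n uint64)

//go:noescape
func sumVec(res *Element, a *Element, n uint64)
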
+#define q0 $0x6fe802ff40300001 +#define q1 $0x421ee5da52bde502 +#define q2 $0xdec1d01aa27a1ae0 +#define q3 $0xd3f7498be97c5eaf +#define q4 $0x04c23a02b586d650 -#include "textflag.h" -#include "funcdata.h" +#include "../../../field/asm/element_5w_amd64.h" // modulus q DATA q<>+0(SB)/8, $0x6fe802ff40300001 @@ -29,244 +20,3 @@ GLOBL q<>(SB), (RODATA+NOPTR), $40 DATA qInv0<>(SB)/8, $0x702ff9ff402fffff GLOBL qInv0<>(SB), (RODATA+NOPTR), $8 -#define REDUCE(ra0, ra1, ra2, ra3, ra4, rb0, rb1, rb2, rb3, rb4) \ - MOVQ ra0, rb0; \ - SUBQ q<>(SB), ra0; \ - MOVQ ra1, rb1; \ - SBBQ q<>+8(SB), ra1; \ - MOVQ ra2, rb2; \ - SBBQ q<>+16(SB), ra2; \ - MOVQ ra3, rb3; \ - SBBQ q<>+24(SB), ra3; \ - MOVQ ra4, rb4; \ - SBBQ q<>+32(SB), ra4; \ - CMOVQCS rb0, ra0; \ - CMOVQCS rb1, ra1; \ - CMOVQCS rb2, ra2; \ - CMOVQCS rb3, ra3; \ - CMOVQCS rb4, ra4; \ - -TEXT ·reduce(SB), NOSPLIT, $0-8 - MOVQ res+0(FP), AX - MOVQ 0(AX), DX - MOVQ 8(AX), CX - MOVQ 16(AX), BX - MOVQ 24(AX), SI - MOVQ 32(AX), DI - - // reduce element(DX,CX,BX,SI,DI) using temp registers (R8,R9,R10,R11,R12) - REDUCE(DX,CX,BX,SI,DI,R8,R9,R10,R11,R12) - - MOVQ DX, 0(AX) - MOVQ CX, 8(AX) - MOVQ BX, 16(AX) - MOVQ SI, 24(AX) - MOVQ DI, 32(AX) - RET - -// MulBy3(x *Element) -TEXT ·MulBy3(SB), NOSPLIT, $0-8 - MOVQ x+0(FP), AX - MOVQ 0(AX), DX - MOVQ 8(AX), CX - MOVQ 16(AX), BX - MOVQ 24(AX), SI - MOVQ 32(AX), DI - ADDQ DX, DX - ADCQ CX, CX - ADCQ BX, BX - ADCQ SI, SI - ADCQ DI, DI - - // reduce element(DX,CX,BX,SI,DI) using temp registers (R8,R9,R10,R11,R12) - REDUCE(DX,CX,BX,SI,DI,R8,R9,R10,R11,R12) - - ADDQ 0(AX), DX - ADCQ 8(AX), CX - ADCQ 16(AX), BX - ADCQ 24(AX), SI - ADCQ 32(AX), DI - - // reduce element(DX,CX,BX,SI,DI) using temp registers (R13,R14,R15,R8,R9) - REDUCE(DX,CX,BX,SI,DI,R13,R14,R15,R8,R9) - - MOVQ DX, 0(AX) - MOVQ CX, 8(AX) - MOVQ BX, 16(AX) - MOVQ SI, 24(AX) - MOVQ DI, 32(AX) - RET - -// MulBy5(x *Element) -TEXT ·MulBy5(SB), NOSPLIT, $0-8 - MOVQ x+0(FP), AX - MOVQ 0(AX), DX - MOVQ 8(AX), CX - MOVQ 16(AX), BX - MOVQ 24(AX), SI - MOVQ 32(AX), DI - ADDQ DX, DX - ADCQ CX, CX - ADCQ BX, BX - ADCQ SI, SI - ADCQ DI, DI - - // reduce element(DX,CX,BX,SI,DI) using temp registers (R8,R9,R10,R11,R12) - REDUCE(DX,CX,BX,SI,DI,R8,R9,R10,R11,R12) - - ADDQ DX, DX - ADCQ CX, CX - ADCQ BX, BX - ADCQ SI, SI - ADCQ DI, DI - - // reduce element(DX,CX,BX,SI,DI) using temp registers (R13,R14,R15,R8,R9) - REDUCE(DX,CX,BX,SI,DI,R13,R14,R15,R8,R9) - - ADDQ 0(AX), DX - ADCQ 8(AX), CX - ADCQ 16(AX), BX - ADCQ 24(AX), SI - ADCQ 32(AX), DI - - // reduce element(DX,CX,BX,SI,DI) using temp registers (R10,R11,R12,R13,R14) - REDUCE(DX,CX,BX,SI,DI,R10,R11,R12,R13,R14) - - MOVQ DX, 0(AX) - MOVQ CX, 8(AX) - MOVQ BX, 16(AX) - MOVQ SI, 24(AX) - MOVQ DI, 32(AX) - RET - -// MulBy13(x *Element) -TEXT ·MulBy13(SB), $16-8 - MOVQ x+0(FP), AX - MOVQ 0(AX), DX - MOVQ 8(AX), CX - MOVQ 16(AX), BX - MOVQ 24(AX), SI - MOVQ 32(AX), DI - ADDQ DX, DX - ADCQ CX, CX - ADCQ BX, BX - ADCQ SI, SI - ADCQ DI, DI - - // reduce element(DX,CX,BX,SI,DI) using temp registers (R8,R9,R10,R11,R12) - REDUCE(DX,CX,BX,SI,DI,R8,R9,R10,R11,R12) - - ADDQ DX, DX - ADCQ CX, CX - ADCQ BX, BX - ADCQ SI, SI - ADCQ DI, DI - - // reduce element(DX,CX,BX,SI,DI) using temp registers (R13,R14,R15,s0-8(SP),s1-16(SP)) - REDUCE(DX,CX,BX,SI,DI,R13,R14,R15,s0-8(SP),s1-16(SP)) - - MOVQ DX, R13 - MOVQ CX, R14 - MOVQ BX, R15 - MOVQ SI, s0-8(SP) - MOVQ DI, s1-16(SP) - ADDQ DX, DX - ADCQ CX, CX - ADCQ BX, BX - ADCQ SI, SI - ADCQ DI, DI - - // reduce element(DX,CX,BX,SI,DI) using temp registers (R8,R9,R10,R11,R12) - 
REDUCE(DX,CX,BX,SI,DI,R8,R9,R10,R11,R12) - - ADDQ R13, DX - ADCQ R14, CX - ADCQ R15, BX - ADCQ s0-8(SP), SI - ADCQ s1-16(SP), DI - - // reduce element(DX,CX,BX,SI,DI) using temp registers (R8,R9,R10,R11,R12) - REDUCE(DX,CX,BX,SI,DI,R8,R9,R10,R11,R12) - - ADDQ 0(AX), DX - ADCQ 8(AX), CX - ADCQ 16(AX), BX - ADCQ 24(AX), SI - ADCQ 32(AX), DI - - // reduce element(DX,CX,BX,SI,DI) using temp registers (R8,R9,R10,R11,R12) - REDUCE(DX,CX,BX,SI,DI,R8,R9,R10,R11,R12) - - MOVQ DX, 0(AX) - MOVQ CX, 8(AX) - MOVQ BX, 16(AX) - MOVQ SI, 24(AX) - MOVQ DI, 32(AX) - RET - -// Butterfly(a, b *Element) sets a = a + b; b = a - b -TEXT ·Butterfly(SB), $24-16 - MOVQ a+0(FP), AX - MOVQ 0(AX), CX - MOVQ 8(AX), BX - MOVQ 16(AX), SI - MOVQ 24(AX), DI - MOVQ 32(AX), R8 - MOVQ CX, R9 - MOVQ BX, R10 - MOVQ SI, R11 - MOVQ DI, R12 - MOVQ R8, R13 - XORQ AX, AX - MOVQ b+8(FP), DX - ADDQ 0(DX), CX - ADCQ 8(DX), BX - ADCQ 16(DX), SI - ADCQ 24(DX), DI - ADCQ 32(DX), R8 - SUBQ 0(DX), R9 - SBBQ 8(DX), R10 - SBBQ 16(DX), R11 - SBBQ 24(DX), R12 - SBBQ 32(DX), R13 - MOVQ CX, R14 - MOVQ BX, R15 - MOVQ SI, s0-8(SP) - MOVQ DI, s1-16(SP) - MOVQ R8, s2-24(SP) - MOVQ $0x6fe802ff40300001, CX - MOVQ $0x421ee5da52bde502, BX - MOVQ $0xdec1d01aa27a1ae0, SI - MOVQ $0xd3f7498be97c5eaf, DI - MOVQ $0x04c23a02b586d650, R8 - CMOVQCC AX, CX - CMOVQCC AX, BX - CMOVQCC AX, SI - CMOVQCC AX, DI - CMOVQCC AX, R8 - ADDQ CX, R9 - ADCQ BX, R10 - ADCQ SI, R11 - ADCQ DI, R12 - ADCQ R8, R13 - MOVQ R14, CX - MOVQ R15, BX - MOVQ s0-8(SP), SI - MOVQ s1-16(SP), DI - MOVQ s2-24(SP), R8 - MOVQ R9, 0(DX) - MOVQ R10, 8(DX) - MOVQ R11, 16(DX) - MOVQ R12, 24(DX) - MOVQ R13, 32(DX) - - // reduce element(CX,BX,SI,DI,R8) using temp registers (R9,R10,R11,R12,R13) - REDUCE(CX,BX,SI,DI,R8,R9,R10,R11,R12,R13) - - MOVQ a+0(FP), AX - MOVQ CX, 0(AX) - MOVQ BX, 8(AX) - MOVQ SI, 16(AX) - MOVQ DI, 24(AX) - MOVQ R8, 32(AX) - RET diff --git a/ecc/bw6-633/fr/element_ops_purego.go b/ecc/bw6-633/fr/element_ops_purego.go index 34d6c54fb..4a7cdbfe4 100644 --- a/ecc/bw6-633/fr/element_ops_purego.go +++ b/ecc/bw6-633/fr/element_ops_purego.go @@ -66,48 +66,8 @@ func reduce(z *Element) { // x and y must be less than q func (z *Element) Mul(x, y *Element) *Element { - // Implements CIOS multiplication -- section 2.3.2 of Tolga Acar's thesis - // https://www.microsoft.com/en-us/research/wp-content/uploads/1998/06/97Acar.pdf - // - // The algorithm: - // - // for i=0 to N-1 - // C := 0 - // for j=0 to N-1 - // (C,t[j]) := t[j] + x[j]*y[i] + C - // (t[N+1],t[N]) := t[N] + C - // - // C := 0 - // m := t[0]*q'[0] mod D - // (C,_) := t[0] + m*q[0] - // for j=1 to N-1 - // (C,t[j-1]) := t[j] + m*q[j] + C - // - // (C,t[N-1]) := t[N] + C - // t[N] := t[N+1] + C - // - // → N is the number of machine words needed to store the modulus q - // → D is the word size. For example, on a 64-bit architecture D is 2 64 - // → x[i], y[i], q[i] is the ith word of the numbers x,y,q - // → q'[0] is the lowest word of the number -q⁻¹ mod r. This quantity is pre-computed, as it does not depend on the inputs. - // → t is a temporary array of size N+2 - // → C, S are machine words. 
A pair (C,S) refers to (hi-bits, lo-bits) of a two-word number - // - // As described here https://hackmd.io/@gnark/modular_multiplication we can get rid of one carry chain and simplify: - // (also described in https://eprint.iacr.org/2022/1400.pdf annex) - // - // for i=0 to N-1 - // (A,t[0]) := t[0] + x[0]*y[i] - // m := t[0]*q'[0] mod W - // C,_ := t[0] + m*q[0] - // for j=1 to N-1 - // (A,t[j]) := t[j] + x[j]*y[i] + A - // (C,t[j-1]) := t[j] + m*q[j] + C - // - // t[N-1] = C + A - // - // This optimization saves 5N + 2 additions in the algorithm, and can be used whenever the highest bit - // of the modulus is zero (and not all of the remaining bits are set). + // Algorithm 2 of "Faster Montgomery Multiplication and Multi-Scalar-Multiplication for SNARKS" + // by Y. El Housni and G. Botrel https://doi.org/10.46586/tches.v2023.i3.504-521 var t0, t1, t2, t3, t4 uint64 var u0, u1, u2, u3, u4 uint64 diff --git a/ecc/bw6-633/fr/element_test.go b/ecc/bw6-633/fr/element_test.go index 19c176ae3..60ffa7a80 100644 --- a/ecc/bw6-633/fr/element_test.go +++ b/ecc/bw6-633/fr/element_test.go @@ -639,7 +639,6 @@ func TestElementBitLen(t *testing.T) { )) properties.TestingRun(t, gopter.ConsoleReporter(false)) - } func TestElementButterflies(t *testing.T) { @@ -721,11 +720,17 @@ func TestElementVecOps(t *testing.T) { // set m to max values element // it's not really q-1 (since we have montgomery representation) // but it's the "largest" legal value - qMinus1 := new(big.Int).Sub(Modulus(), big.NewInt(1)) - - var eQMinus1 Element - for i, v := range qMinus1.Bits() { - eQMinus1[i] = uint64(v) + eQMinus1 := qElement + if eQMinus1[0] != 0 { + eQMinus1[0]-- + } else { + eQMinus1[0] = ^uint64(0) + for i := 1; i < len(eQMinus1); i++ { + if eQMinus1[i] != 0 { + eQMinus1[i]-- + break + } + } } for i := 0; i < N; i++ { diff --git a/ecc/bw6-761/fp/element.go b/ecc/bw6-761/fp/element.go index 36232ebff..8cdd31218 100644 --- a/ecc/bw6-761/fp/element.go +++ b/ecc/bw6-761/fp/element.go @@ -653,32 +653,8 @@ func (z *Element) Select(c int, x0 *Element, x1 *Element) *Element { // and is used for testing purposes. func _mulGeneric(z, x, y *Element) { - // Implements CIOS multiplication -- section 2.3.2 of Tolga Acar's thesis - // https://www.microsoft.com/en-us/research/wp-content/uploads/1998/06/97Acar.pdf - // - // The algorithm: - // - // for i=0 to N-1 - // C := 0 - // for j=0 to N-1 - // (C,t[j]) := t[j] + x[j]*y[i] + C - // (t[N+1],t[N]) := t[N] + C - // - // C := 0 - // m := t[0]*q'[0] mod D - // (C,_) := t[0] + m*q[0] - // for j=1 to N-1 - // (C,t[j-1]) := t[j] + m*q[j] + C - // - // (C,t[N-1]) := t[N] + C - // t[N] := t[N+1] + C - // - // → N is the number of machine words needed to store the modulus q - // → D is the word size. For example, on a 64-bit architecture D is 2 64 - // → x[i], y[i], q[i] is the ith word of the numbers x,y,q - // → q'[0] is the lowest word of the number -q⁻¹ mod r. This quantity is pre-computed, as it does not depend on the inputs. - // → t is a temporary array of size N+2 - // → C, S are machine words. A pair (C,S) refers to (hi-bits, lo-bits) of a two-word number + // Algorithm 2 of "Faster Montgomery Multiplication and Multi-Scalar-Multiplication for SNARKS" + // by Y. El Housni and G. 
Botrel https://doi.org/10.46586/tches.v2023.i3.504-521 var t [13]uint64 var D uint64 diff --git a/ecc/bw6-761/fp/element_ops_amd64.go b/ecc/bw6-761/fp/element_ops_amd64.go index 83bba45ae..ed2803d71 100644 --- a/ecc/bw6-761/fp/element_ops_amd64.go +++ b/ecc/bw6-761/fp/element_ops_amd64.go @@ -50,48 +50,8 @@ func Butterfly(a, b *Element) // x and y must be less than q func (z *Element) Mul(x, y *Element) *Element { - // Implements CIOS multiplication -- section 2.3.2 of Tolga Acar's thesis - // https://www.microsoft.com/en-us/research/wp-content/uploads/1998/06/97Acar.pdf - // - // The algorithm: - // - // for i=0 to N-1 - // C := 0 - // for j=0 to N-1 - // (C,t[j]) := t[j] + x[j]*y[i] + C - // (t[N+1],t[N]) := t[N] + C - // - // C := 0 - // m := t[0]*q'[0] mod D - // (C,_) := t[0] + m*q[0] - // for j=1 to N-1 - // (C,t[j-1]) := t[j] + m*q[j] + C - // - // (C,t[N-1]) := t[N] + C - // t[N] := t[N+1] + C - // - // → N is the number of machine words needed to store the modulus q - // → D is the word size. For example, on a 64-bit architecture D is 2 64 - // → x[i], y[i], q[i] is the ith word of the numbers x,y,q - // → q'[0] is the lowest word of the number -q⁻¹ mod r. This quantity is pre-computed, as it does not depend on the inputs. - // → t is a temporary array of size N+2 - // → C, S are machine words. A pair (C,S) refers to (hi-bits, lo-bits) of a two-word number - // - // As described here https://hackmd.io/@gnark/modular_multiplication we can get rid of one carry chain and simplify: - // (also described in https://eprint.iacr.org/2022/1400.pdf annex) - // - // for i=0 to N-1 - // (A,t[0]) := t[0] + x[0]*y[i] - // m := t[0]*q'[0] mod W - // C,_ := t[0] + m*q[0] - // for j=1 to N-1 - // (A,t[j]) := t[j] + x[j]*y[i] + A - // (C,t[j-1]) := t[j] + m*q[j] + C - // - // t[N-1] = C + A - // - // This optimization saves 5N + 2 additions in the algorithm, and can be used whenever the highest bit - // of the modulus is zero (and not all of the remaining bits are set). + // Algorithm 2 of "Faster Montgomery Multiplication and Multi-Scalar-Multiplication for SNARKS" + // by Y. El Housni and G. Botrel https://doi.org/10.46586/tches.v2023.i3.504-521 mul(z, x, y) return z diff --git a/ecc/bw6-761/fp/element_ops_amd64.s b/ecc/bw6-761/fp/element_ops_amd64.s index 476e9e39e..aac74663c 100644 --- a/ecc/bw6-761/fp/element_ops_amd64.s +++ b/ecc/bw6-761/fp/element_ops_amd64.s @@ -1,21 +1,19 @@ // +build !purego -// Copyright 2020 ConsenSys Software Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
- -#include "textflag.h" -#include "funcdata.h" +#define q0 $0xf49d00000000008b +#define q1 $0xe6913e6870000082 +#define q2 $0x160cf8aeeaf0a437 +#define q3 $0x98a116c25667a8f8 +#define q4 $0x71dcd3dc73ebff2e +#define q5 $0x8689c8ed12f9fd90 +#define q6 $0x03cebaff25b42304 +#define q7 $0x707ba638e584e919 +#define q8 $0x528275ef8087be41 +#define q9 $0xb926186a81d14688 +#define q10 $0xd187c94004faff3e +#define q11 $0x0122e824fb83ce0a + +#include "../../../field/asm/element_12w_amd64.h" // modulus q DATA q<>+0(SB)/8, $0xf49d00000000008b @@ -36,467 +34,3 @@ GLOBL q<>(SB), (RODATA+NOPTR), $96 DATA qInv0<>(SB)/8, $0x0a5593568fa798dd GLOBL qInv0<>(SB), (RODATA+NOPTR), $8 -#define REDUCE(ra0, ra1, ra2, ra3, ra4, ra5, ra6, ra7, ra8, ra9, ra10, ra11, rb0, rb1, rb2, rb3, rb4, rb5, rb6, rb7, rb8, rb9, rb10, rb11) \ - MOVQ ra0, rb0; \ - SUBQ q<>(SB), ra0; \ - MOVQ ra1, rb1; \ - SBBQ q<>+8(SB), ra1; \ - MOVQ ra2, rb2; \ - SBBQ q<>+16(SB), ra2; \ - MOVQ ra3, rb3; \ - SBBQ q<>+24(SB), ra3; \ - MOVQ ra4, rb4; \ - SBBQ q<>+32(SB), ra4; \ - MOVQ ra5, rb5; \ - SBBQ q<>+40(SB), ra5; \ - MOVQ ra6, rb6; \ - SBBQ q<>+48(SB), ra6; \ - MOVQ ra7, rb7; \ - SBBQ q<>+56(SB), ra7; \ - MOVQ ra8, rb8; \ - SBBQ q<>+64(SB), ra8; \ - MOVQ ra9, rb9; \ - SBBQ q<>+72(SB), ra9; \ - MOVQ ra10, rb10; \ - SBBQ q<>+80(SB), ra10; \ - MOVQ ra11, rb11; \ - SBBQ q<>+88(SB), ra11; \ - CMOVQCS rb0, ra0; \ - CMOVQCS rb1, ra1; \ - CMOVQCS rb2, ra2; \ - CMOVQCS rb3, ra3; \ - CMOVQCS rb4, ra4; \ - CMOVQCS rb5, ra5; \ - CMOVQCS rb6, ra6; \ - CMOVQCS rb7, ra7; \ - CMOVQCS rb8, ra8; \ - CMOVQCS rb9, ra9; \ - CMOVQCS rb10, ra10; \ - CMOVQCS rb11, ra11; \ - -TEXT ·reduce(SB), $88-8 - MOVQ res+0(FP), AX - MOVQ 0(AX), DX - MOVQ 8(AX), CX - MOVQ 16(AX), BX - MOVQ 24(AX), SI - MOVQ 32(AX), DI - MOVQ 40(AX), R8 - MOVQ 48(AX), R9 - MOVQ 56(AX), R10 - MOVQ 64(AX), R11 - MOVQ 72(AX), R12 - MOVQ 80(AX), R13 - MOVQ 88(AX), R14 - - // reduce element(DX,CX,BX,SI,DI,R8,R9,R10,R11,R12,R13,R14) using temp registers (R15,s0-8(SP),s1-16(SP),s2-24(SP),s3-32(SP),s4-40(SP),s5-48(SP),s6-56(SP),s7-64(SP),s8-72(SP),s9-80(SP),s10-88(SP)) - REDUCE(DX,CX,BX,SI,DI,R8,R9,R10,R11,R12,R13,R14,R15,s0-8(SP),s1-16(SP),s2-24(SP),s3-32(SP),s4-40(SP),s5-48(SP),s6-56(SP),s7-64(SP),s8-72(SP),s9-80(SP),s10-88(SP)) - - MOVQ DX, 0(AX) - MOVQ CX, 8(AX) - MOVQ BX, 16(AX) - MOVQ SI, 24(AX) - MOVQ DI, 32(AX) - MOVQ R8, 40(AX) - MOVQ R9, 48(AX) - MOVQ R10, 56(AX) - MOVQ R11, 64(AX) - MOVQ R12, 72(AX) - MOVQ R13, 80(AX) - MOVQ R14, 88(AX) - RET - -// MulBy3(x *Element) -TEXT ·MulBy3(SB), $88-8 - MOVQ x+0(FP), AX - MOVQ 0(AX), DX - MOVQ 8(AX), CX - MOVQ 16(AX), BX - MOVQ 24(AX), SI - MOVQ 32(AX), DI - MOVQ 40(AX), R8 - MOVQ 48(AX), R9 - MOVQ 56(AX), R10 - MOVQ 64(AX), R11 - MOVQ 72(AX), R12 - MOVQ 80(AX), R13 - MOVQ 88(AX), R14 - ADDQ DX, DX - ADCQ CX, CX - ADCQ BX, BX - ADCQ SI, SI - ADCQ DI, DI - ADCQ R8, R8 - ADCQ R9, R9 - ADCQ R10, R10 - ADCQ R11, R11 - ADCQ R12, R12 - ADCQ R13, R13 - ADCQ R14, R14 - - // reduce element(DX,CX,BX,SI,DI,R8,R9,R10,R11,R12,R13,R14) using temp registers (R15,s0-8(SP),s1-16(SP),s2-24(SP),s3-32(SP),s4-40(SP),s5-48(SP),s6-56(SP),s7-64(SP),s8-72(SP),s9-80(SP),s10-88(SP)) - REDUCE(DX,CX,BX,SI,DI,R8,R9,R10,R11,R12,R13,R14,R15,s0-8(SP),s1-16(SP),s2-24(SP),s3-32(SP),s4-40(SP),s5-48(SP),s6-56(SP),s7-64(SP),s8-72(SP),s9-80(SP),s10-88(SP)) - - ADDQ 0(AX), DX - ADCQ 8(AX), CX - ADCQ 16(AX), BX - ADCQ 24(AX), SI - ADCQ 32(AX), DI - ADCQ 40(AX), R8 - ADCQ 48(AX), R9 - ADCQ 56(AX), R10 - ADCQ 64(AX), R11 - ADCQ 72(AX), R12 - ADCQ 80(AX), R13 - ADCQ 88(AX), R14 - - // reduce 
element(DX,CX,BX,SI,DI,R8,R9,R10,R11,R12,R13,R14) using temp registers (R15,s0-8(SP),s1-16(SP),s2-24(SP),s3-32(SP),s4-40(SP),s5-48(SP),s6-56(SP),s7-64(SP),s8-72(SP),s9-80(SP),s10-88(SP)) - REDUCE(DX,CX,BX,SI,DI,R8,R9,R10,R11,R12,R13,R14,R15,s0-8(SP),s1-16(SP),s2-24(SP),s3-32(SP),s4-40(SP),s5-48(SP),s6-56(SP),s7-64(SP),s8-72(SP),s9-80(SP),s10-88(SP)) - - MOVQ DX, 0(AX) - MOVQ CX, 8(AX) - MOVQ BX, 16(AX) - MOVQ SI, 24(AX) - MOVQ DI, 32(AX) - MOVQ R8, 40(AX) - MOVQ R9, 48(AX) - MOVQ R10, 56(AX) - MOVQ R11, 64(AX) - MOVQ R12, 72(AX) - MOVQ R13, 80(AX) - MOVQ R14, 88(AX) - RET - -// MulBy5(x *Element) -TEXT ·MulBy5(SB), $88-8 - MOVQ x+0(FP), AX - MOVQ 0(AX), DX - MOVQ 8(AX), CX - MOVQ 16(AX), BX - MOVQ 24(AX), SI - MOVQ 32(AX), DI - MOVQ 40(AX), R8 - MOVQ 48(AX), R9 - MOVQ 56(AX), R10 - MOVQ 64(AX), R11 - MOVQ 72(AX), R12 - MOVQ 80(AX), R13 - MOVQ 88(AX), R14 - ADDQ DX, DX - ADCQ CX, CX - ADCQ BX, BX - ADCQ SI, SI - ADCQ DI, DI - ADCQ R8, R8 - ADCQ R9, R9 - ADCQ R10, R10 - ADCQ R11, R11 - ADCQ R12, R12 - ADCQ R13, R13 - ADCQ R14, R14 - - // reduce element(DX,CX,BX,SI,DI,R8,R9,R10,R11,R12,R13,R14) using temp registers (R15,s0-8(SP),s1-16(SP),s2-24(SP),s3-32(SP),s4-40(SP),s5-48(SP),s6-56(SP),s7-64(SP),s8-72(SP),s9-80(SP),s10-88(SP)) - REDUCE(DX,CX,BX,SI,DI,R8,R9,R10,R11,R12,R13,R14,R15,s0-8(SP),s1-16(SP),s2-24(SP),s3-32(SP),s4-40(SP),s5-48(SP),s6-56(SP),s7-64(SP),s8-72(SP),s9-80(SP),s10-88(SP)) - - ADDQ DX, DX - ADCQ CX, CX - ADCQ BX, BX - ADCQ SI, SI - ADCQ DI, DI - ADCQ R8, R8 - ADCQ R9, R9 - ADCQ R10, R10 - ADCQ R11, R11 - ADCQ R12, R12 - ADCQ R13, R13 - ADCQ R14, R14 - - // reduce element(DX,CX,BX,SI,DI,R8,R9,R10,R11,R12,R13,R14) using temp registers (R15,s0-8(SP),s1-16(SP),s2-24(SP),s3-32(SP),s4-40(SP),s5-48(SP),s6-56(SP),s7-64(SP),s8-72(SP),s9-80(SP),s10-88(SP)) - REDUCE(DX,CX,BX,SI,DI,R8,R9,R10,R11,R12,R13,R14,R15,s0-8(SP),s1-16(SP),s2-24(SP),s3-32(SP),s4-40(SP),s5-48(SP),s6-56(SP),s7-64(SP),s8-72(SP),s9-80(SP),s10-88(SP)) - - ADDQ 0(AX), DX - ADCQ 8(AX), CX - ADCQ 16(AX), BX - ADCQ 24(AX), SI - ADCQ 32(AX), DI - ADCQ 40(AX), R8 - ADCQ 48(AX), R9 - ADCQ 56(AX), R10 - ADCQ 64(AX), R11 - ADCQ 72(AX), R12 - ADCQ 80(AX), R13 - ADCQ 88(AX), R14 - - // reduce element(DX,CX,BX,SI,DI,R8,R9,R10,R11,R12,R13,R14) using temp registers (R15,s0-8(SP),s1-16(SP),s2-24(SP),s3-32(SP),s4-40(SP),s5-48(SP),s6-56(SP),s7-64(SP),s8-72(SP),s9-80(SP),s10-88(SP)) - REDUCE(DX,CX,BX,SI,DI,R8,R9,R10,R11,R12,R13,R14,R15,s0-8(SP),s1-16(SP),s2-24(SP),s3-32(SP),s4-40(SP),s5-48(SP),s6-56(SP),s7-64(SP),s8-72(SP),s9-80(SP),s10-88(SP)) - - MOVQ DX, 0(AX) - MOVQ CX, 8(AX) - MOVQ BX, 16(AX) - MOVQ SI, 24(AX) - MOVQ DI, 32(AX) - MOVQ R8, 40(AX) - MOVQ R9, 48(AX) - MOVQ R10, 56(AX) - MOVQ R11, 64(AX) - MOVQ R12, 72(AX) - MOVQ R13, 80(AX) - MOVQ R14, 88(AX) - RET - -// MulBy13(x *Element) -TEXT ·MulBy13(SB), $184-8 - MOVQ x+0(FP), AX - MOVQ 0(AX), DX - MOVQ 8(AX), CX - MOVQ 16(AX), BX - MOVQ 24(AX), SI - MOVQ 32(AX), DI - MOVQ 40(AX), R8 - MOVQ 48(AX), R9 - MOVQ 56(AX), R10 - MOVQ 64(AX), R11 - MOVQ 72(AX), R12 - MOVQ 80(AX), R13 - MOVQ 88(AX), R14 - ADDQ DX, DX - ADCQ CX, CX - ADCQ BX, BX - ADCQ SI, SI - ADCQ DI, DI - ADCQ R8, R8 - ADCQ R9, R9 - ADCQ R10, R10 - ADCQ R11, R11 - ADCQ R12, R12 - ADCQ R13, R13 - ADCQ R14, R14 - - // reduce element(DX,CX,BX,SI,DI,R8,R9,R10,R11,R12,R13,R14) using temp registers (R15,s0-8(SP),s1-16(SP),s2-24(SP),s3-32(SP),s4-40(SP),s5-48(SP),s6-56(SP),s7-64(SP),s8-72(SP),s9-80(SP),s10-88(SP)) - 
REDUCE(DX,CX,BX,SI,DI,R8,R9,R10,R11,R12,R13,R14,R15,s0-8(SP),s1-16(SP),s2-24(SP),s3-32(SP),s4-40(SP),s5-48(SP),s6-56(SP),s7-64(SP),s8-72(SP),s9-80(SP),s10-88(SP)) - - ADDQ DX, DX - ADCQ CX, CX - ADCQ BX, BX - ADCQ SI, SI - ADCQ DI, DI - ADCQ R8, R8 - ADCQ R9, R9 - ADCQ R10, R10 - ADCQ R11, R11 - ADCQ R12, R12 - ADCQ R13, R13 - ADCQ R14, R14 - - // reduce element(DX,CX,BX,SI,DI,R8,R9,R10,R11,R12,R13,R14) using temp registers (s11-96(SP),s12-104(SP),s13-112(SP),s14-120(SP),s15-128(SP),s16-136(SP),s17-144(SP),s18-152(SP),s19-160(SP),s20-168(SP),s21-176(SP),s22-184(SP)) - REDUCE(DX,CX,BX,SI,DI,R8,R9,R10,R11,R12,R13,R14,s11-96(SP),s12-104(SP),s13-112(SP),s14-120(SP),s15-128(SP),s16-136(SP),s17-144(SP),s18-152(SP),s19-160(SP),s20-168(SP),s21-176(SP),s22-184(SP)) - - MOVQ DX, s11-96(SP) - MOVQ CX, s12-104(SP) - MOVQ BX, s13-112(SP) - MOVQ SI, s14-120(SP) - MOVQ DI, s15-128(SP) - MOVQ R8, s16-136(SP) - MOVQ R9, s17-144(SP) - MOVQ R10, s18-152(SP) - MOVQ R11, s19-160(SP) - MOVQ R12, s20-168(SP) - MOVQ R13, s21-176(SP) - MOVQ R14, s22-184(SP) - ADDQ DX, DX - ADCQ CX, CX - ADCQ BX, BX - ADCQ SI, SI - ADCQ DI, DI - ADCQ R8, R8 - ADCQ R9, R9 - ADCQ R10, R10 - ADCQ R11, R11 - ADCQ R12, R12 - ADCQ R13, R13 - ADCQ R14, R14 - - // reduce element(DX,CX,BX,SI,DI,R8,R9,R10,R11,R12,R13,R14) using temp registers (R15,s0-8(SP),s1-16(SP),s2-24(SP),s3-32(SP),s4-40(SP),s5-48(SP),s6-56(SP),s7-64(SP),s8-72(SP),s9-80(SP),s10-88(SP)) - REDUCE(DX,CX,BX,SI,DI,R8,R9,R10,R11,R12,R13,R14,R15,s0-8(SP),s1-16(SP),s2-24(SP),s3-32(SP),s4-40(SP),s5-48(SP),s6-56(SP),s7-64(SP),s8-72(SP),s9-80(SP),s10-88(SP)) - - ADDQ s11-96(SP), DX - ADCQ s12-104(SP), CX - ADCQ s13-112(SP), BX - ADCQ s14-120(SP), SI - ADCQ s15-128(SP), DI - ADCQ s16-136(SP), R8 - ADCQ s17-144(SP), R9 - ADCQ s18-152(SP), R10 - ADCQ s19-160(SP), R11 - ADCQ s20-168(SP), R12 - ADCQ s21-176(SP), R13 - ADCQ s22-184(SP), R14 - - // reduce element(DX,CX,BX,SI,DI,R8,R9,R10,R11,R12,R13,R14) using temp registers (R15,s0-8(SP),s1-16(SP),s2-24(SP),s3-32(SP),s4-40(SP),s5-48(SP),s6-56(SP),s7-64(SP),s8-72(SP),s9-80(SP),s10-88(SP)) - REDUCE(DX,CX,BX,SI,DI,R8,R9,R10,R11,R12,R13,R14,R15,s0-8(SP),s1-16(SP),s2-24(SP),s3-32(SP),s4-40(SP),s5-48(SP),s6-56(SP),s7-64(SP),s8-72(SP),s9-80(SP),s10-88(SP)) - - ADDQ 0(AX), DX - ADCQ 8(AX), CX - ADCQ 16(AX), BX - ADCQ 24(AX), SI - ADCQ 32(AX), DI - ADCQ 40(AX), R8 - ADCQ 48(AX), R9 - ADCQ 56(AX), R10 - ADCQ 64(AX), R11 - ADCQ 72(AX), R12 - ADCQ 80(AX), R13 - ADCQ 88(AX), R14 - - // reduce element(DX,CX,BX,SI,DI,R8,R9,R10,R11,R12,R13,R14) using temp registers (R15,s0-8(SP),s1-16(SP),s2-24(SP),s3-32(SP),s4-40(SP),s5-48(SP),s6-56(SP),s7-64(SP),s8-72(SP),s9-80(SP),s10-88(SP)) - REDUCE(DX,CX,BX,SI,DI,R8,R9,R10,R11,R12,R13,R14,R15,s0-8(SP),s1-16(SP),s2-24(SP),s3-32(SP),s4-40(SP),s5-48(SP),s6-56(SP),s7-64(SP),s8-72(SP),s9-80(SP),s10-88(SP)) - - MOVQ DX, 0(AX) - MOVQ CX, 8(AX) - MOVQ BX, 16(AX) - MOVQ SI, 24(AX) - MOVQ DI, 32(AX) - MOVQ R8, 40(AX) - MOVQ R9, 48(AX) - MOVQ R10, 56(AX) - MOVQ R11, 64(AX) - MOVQ R12, 72(AX) - MOVQ R13, 80(AX) - MOVQ R14, 88(AX) - RET - -// Butterfly(a, b *Element) sets a = a + b; b = a - b -TEXT ·Butterfly(SB), $88-16 - MOVQ b+8(FP), AX - MOVQ 0(AX), DX - MOVQ 8(AX), CX - MOVQ 16(AX), BX - MOVQ 24(AX), SI - MOVQ 32(AX), DI - MOVQ 40(AX), R8 - MOVQ 48(AX), R9 - MOVQ 56(AX), R10 - MOVQ 64(AX), R11 - MOVQ 72(AX), R12 - MOVQ 80(AX), R13 - MOVQ 88(AX), R14 - MOVQ a+0(FP), AX - ADDQ 0(AX), DX - ADCQ 8(AX), CX - ADCQ 16(AX), BX - ADCQ 24(AX), SI - ADCQ 32(AX), DI - ADCQ 40(AX), R8 - ADCQ 48(AX), R9 - ADCQ 56(AX), R10 - ADCQ 
64(AX), R11 - ADCQ 72(AX), R12 - ADCQ 80(AX), R13 - ADCQ 88(AX), R14 - MOVQ DX, R15 - MOVQ CX, s0-8(SP) - MOVQ BX, s1-16(SP) - MOVQ SI, s2-24(SP) - MOVQ DI, s3-32(SP) - MOVQ R8, s4-40(SP) - MOVQ R9, s5-48(SP) - MOVQ R10, s6-56(SP) - MOVQ R11, s7-64(SP) - MOVQ R12, s8-72(SP) - MOVQ R13, s9-80(SP) - MOVQ R14, s10-88(SP) - MOVQ 0(AX), DX - MOVQ 8(AX), CX - MOVQ 16(AX), BX - MOVQ 24(AX), SI - MOVQ 32(AX), DI - MOVQ 40(AX), R8 - MOVQ 48(AX), R9 - MOVQ 56(AX), R10 - MOVQ 64(AX), R11 - MOVQ 72(AX), R12 - MOVQ 80(AX), R13 - MOVQ 88(AX), R14 - MOVQ b+8(FP), AX - SUBQ 0(AX), DX - SBBQ 8(AX), CX - SBBQ 16(AX), BX - SBBQ 24(AX), SI - SBBQ 32(AX), DI - SBBQ 40(AX), R8 - SBBQ 48(AX), R9 - SBBQ 56(AX), R10 - SBBQ 64(AX), R11 - SBBQ 72(AX), R12 - SBBQ 80(AX), R13 - SBBQ 88(AX), R14 - JCC noReduce_1 - MOVQ $0xf49d00000000008b, AX - ADDQ AX, DX - MOVQ $0xe6913e6870000082, AX - ADCQ AX, CX - MOVQ $0x160cf8aeeaf0a437, AX - ADCQ AX, BX - MOVQ $0x98a116c25667a8f8, AX - ADCQ AX, SI - MOVQ $0x71dcd3dc73ebff2e, AX - ADCQ AX, DI - MOVQ $0x8689c8ed12f9fd90, AX - ADCQ AX, R8 - MOVQ $0x03cebaff25b42304, AX - ADCQ AX, R9 - MOVQ $0x707ba638e584e919, AX - ADCQ AX, R10 - MOVQ $0x528275ef8087be41, AX - ADCQ AX, R11 - MOVQ $0xb926186a81d14688, AX - ADCQ AX, R12 - MOVQ $0xd187c94004faff3e, AX - ADCQ AX, R13 - MOVQ $0x0122e824fb83ce0a, AX - ADCQ AX, R14 - -noReduce_1: - MOVQ b+8(FP), AX - MOVQ DX, 0(AX) - MOVQ CX, 8(AX) - MOVQ BX, 16(AX) - MOVQ SI, 24(AX) - MOVQ DI, 32(AX) - MOVQ R8, 40(AX) - MOVQ R9, 48(AX) - MOVQ R10, 56(AX) - MOVQ R11, 64(AX) - MOVQ R12, 72(AX) - MOVQ R13, 80(AX) - MOVQ R14, 88(AX) - MOVQ R15, DX - MOVQ s0-8(SP), CX - MOVQ s1-16(SP), BX - MOVQ s2-24(SP), SI - MOVQ s3-32(SP), DI - MOVQ s4-40(SP), R8 - MOVQ s5-48(SP), R9 - MOVQ s6-56(SP), R10 - MOVQ s7-64(SP), R11 - MOVQ s8-72(SP), R12 - MOVQ s9-80(SP), R13 - MOVQ s10-88(SP), R14 - - // reduce element(DX,CX,BX,SI,DI,R8,R9,R10,R11,R12,R13,R14) using temp registers (R15,s0-8(SP),s1-16(SP),s2-24(SP),s3-32(SP),s4-40(SP),s5-48(SP),s6-56(SP),s7-64(SP),s8-72(SP),s9-80(SP),s10-88(SP)) - REDUCE(DX,CX,BX,SI,DI,R8,R9,R10,R11,R12,R13,R14,R15,s0-8(SP),s1-16(SP),s2-24(SP),s3-32(SP),s4-40(SP),s5-48(SP),s6-56(SP),s7-64(SP),s8-72(SP),s9-80(SP),s10-88(SP)) - - MOVQ a+0(FP), AX - MOVQ DX, 0(AX) - MOVQ CX, 8(AX) - MOVQ BX, 16(AX) - MOVQ SI, 24(AX) - MOVQ DI, 32(AX) - MOVQ R8, 40(AX) - MOVQ R9, 48(AX) - MOVQ R10, 56(AX) - MOVQ R11, 64(AX) - MOVQ R12, 72(AX) - MOVQ R13, 80(AX) - MOVQ R14, 88(AX) - RET diff --git a/ecc/bw6-761/fp/element_ops_purego.go b/ecc/bw6-761/fp/element_ops_purego.go index 3c1ffa245..59d6d1d52 100644 --- a/ecc/bw6-761/fp/element_ops_purego.go +++ b/ecc/bw6-761/fp/element_ops_purego.go @@ -73,48 +73,8 @@ func reduce(z *Element) { // x and y must be less than q func (z *Element) Mul(x, y *Element) *Element { - // Implements CIOS multiplication -- section 2.3.2 of Tolga Acar's thesis - // https://www.microsoft.com/en-us/research/wp-content/uploads/1998/06/97Acar.pdf - // - // The algorithm: - // - // for i=0 to N-1 - // C := 0 - // for j=0 to N-1 - // (C,t[j]) := t[j] + x[j]*y[i] + C - // (t[N+1],t[N]) := t[N] + C - // - // C := 0 - // m := t[0]*q'[0] mod D - // (C,_) := t[0] + m*q[0] - // for j=1 to N-1 - // (C,t[j-1]) := t[j] + m*q[j] + C - // - // (C,t[N-1]) := t[N] + C - // t[N] := t[N+1] + C - // - // → N is the number of machine words needed to store the modulus q - // → D is the word size. For example, on a 64-bit architecture D is 2 64 - // → x[i], y[i], q[i] is the ith word of the numbers x,y,q - // → q'[0] is the lowest word of the number -q⁻¹ mod r. 
This quantity is pre-computed, as it does not depend on the inputs. - // → t is a temporary array of size N+2 - // → C, S are machine words. A pair (C,S) refers to (hi-bits, lo-bits) of a two-word number - // - // As described here https://hackmd.io/@gnark/modular_multiplication we can get rid of one carry chain and simplify: - // (also described in https://eprint.iacr.org/2022/1400.pdf annex) - // - // for i=0 to N-1 - // (A,t[0]) := t[0] + x[0]*y[i] - // m := t[0]*q'[0] mod W - // C,_ := t[0] + m*q[0] - // for j=1 to N-1 - // (A,t[j]) := t[j] + x[j]*y[i] + A - // (C,t[j-1]) := t[j] + m*q[j] + C - // - // t[N-1] = C + A - // - // This optimization saves 5N + 2 additions in the algorithm, and can be used whenever the highest bit - // of the modulus is zero (and not all of the remaining bits are set). + // Algorithm 2 of "Faster Montgomery Multiplication and Multi-Scalar-Multiplication for SNARKS" + // by Y. El Housni and G. Botrel https://doi.org/10.46586/tches.v2023.i3.504-521 var t0, t1, t2, t3, t4, t5, t6, t7, t8, t9, t10, t11 uint64 var u0, u1, u2, u3, u4, u5, u6, u7, u8, u9, u10, u11 uint64 diff --git a/ecc/bw6-761/fp/element_test.go b/ecc/bw6-761/fp/element_test.go index 07ea72041..58da93cec 100644 --- a/ecc/bw6-761/fp/element_test.go +++ b/ecc/bw6-761/fp/element_test.go @@ -653,7 +653,6 @@ func TestElementBitLen(t *testing.T) { )) properties.TestingRun(t, gopter.ConsoleReporter(false)) - } func TestElementButterflies(t *testing.T) { @@ -735,11 +734,17 @@ func TestElementVecOps(t *testing.T) { // set m to max values element // it's not really q-1 (since we have montgomery representation) // but it's the "largest" legal value - qMinus1 := new(big.Int).Sub(Modulus(), big.NewInt(1)) - - var eQMinus1 Element - for i, v := range qMinus1.Bits() { - eQMinus1[i] = uint64(v) + eQMinus1 := qElement + if eQMinus1[0] != 0 { + eQMinus1[0]-- + } else { + eQMinus1[0] = ^uint64(0) + for i := 1; i < len(eQMinus1); i++ { + if eQMinus1[i] != 0 { + eQMinus1[i]-- + break + } + } } for i := 0; i < N; i++ { diff --git a/ecc/bw6-761/fr/element.go b/ecc/bw6-761/fr/element.go index 3e7eacc9e..6784bc911 100644 --- a/ecc/bw6-761/fr/element.go +++ b/ecc/bw6-761/fr/element.go @@ -521,32 +521,8 @@ func (z *Element) Select(c int, x0 *Element, x1 *Element) *Element { // and is used for testing purposes. func _mulGeneric(z, x, y *Element) { - // Implements CIOS multiplication -- section 2.3.2 of Tolga Acar's thesis - // https://www.microsoft.com/en-us/research/wp-content/uploads/1998/06/97Acar.pdf - // - // The algorithm: - // - // for i=0 to N-1 - // C := 0 - // for j=0 to N-1 - // (C,t[j]) := t[j] + x[j]*y[i] + C - // (t[N+1],t[N]) := t[N] + C - // - // C := 0 - // m := t[0]*q'[0] mod D - // (C,_) := t[0] + m*q[0] - // for j=1 to N-1 - // (C,t[j-1]) := t[j] + m*q[j] + C - // - // (C,t[N-1]) := t[N] + C - // t[N] := t[N+1] + C - // - // → N is the number of machine words needed to store the modulus q - // → D is the word size. For example, on a 64-bit architecture D is 2 64 - // → x[i], y[i], q[i] is the ith word of the numbers x,y,q - // → q'[0] is the lowest word of the number -q⁻¹ mod r. This quantity is pre-computed, as it does not depend on the inputs. - // → t is a temporary array of size N+2 - // → C, S are machine words. A pair (C,S) refers to (hi-bits, lo-bits) of a two-word number + // Algorithm 2 of "Faster Montgomery Multiplication and Multi-Scalar-Multiplication for SNARKS" + // by Y. El Housni and G. 
Botrel https://doi.org/10.46586/tches.v2023.i3.504-521 var t [7]uint64 var D uint64 diff --git a/ecc/bw6-761/fr/element_mul_amd64.s b/ecc/bw6-761/fr/element_mul_amd64.s deleted file mode 100644 index 1e19c4d3f..000000000 --- a/ecc/bw6-761/fr/element_mul_amd64.s +++ /dev/null @@ -1,857 +0,0 @@ -// +build !purego - -// Copyright 2020 ConsenSys Software Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "textflag.h" -#include "funcdata.h" - -// modulus q -DATA q<>+0(SB)/8, $0x8508c00000000001 -DATA q<>+8(SB)/8, $0x170b5d4430000000 -DATA q<>+16(SB)/8, $0x1ef3622fba094800 -DATA q<>+24(SB)/8, $0x1a22d9f300f5138f -DATA q<>+32(SB)/8, $0xc63b05c06ca1493b -DATA q<>+40(SB)/8, $0x01ae3a4617c510ea -GLOBL q<>(SB), (RODATA+NOPTR), $48 - -// qInv0 q'[0] -DATA qInv0<>(SB)/8, $0x8508bfffffffffff -GLOBL qInv0<>(SB), (RODATA+NOPTR), $8 - -#define REDUCE(ra0, ra1, ra2, ra3, ra4, ra5, rb0, rb1, rb2, rb3, rb4, rb5) \ - MOVQ ra0, rb0; \ - SUBQ q<>(SB), ra0; \ - MOVQ ra1, rb1; \ - SBBQ q<>+8(SB), ra1; \ - MOVQ ra2, rb2; \ - SBBQ q<>+16(SB), ra2; \ - MOVQ ra3, rb3; \ - SBBQ q<>+24(SB), ra3; \ - MOVQ ra4, rb4; \ - SBBQ q<>+32(SB), ra4; \ - MOVQ ra5, rb5; \ - SBBQ q<>+40(SB), ra5; \ - CMOVQCS rb0, ra0; \ - CMOVQCS rb1, ra1; \ - CMOVQCS rb2, ra2; \ - CMOVQCS rb3, ra3; \ - CMOVQCS rb4, ra4; \ - CMOVQCS rb5, ra5; \ - -// mul(res, x, y *Element) -TEXT ·mul(SB), $24-24 - - // the algorithm is described in the Element.Mul declaration (.go) - // however, to benefit from the ADCX and ADOX carry chains - // we split the inner loops in 2: - // for i=0 to N-1 - // for j=0 to N-1 - // (A,t[j]) := t[j] + x[j]*y[i] + A - // m := t[0]*q'[0] mod W - // C,_ := t[0] + m*q[0] - // for j=1 to N-1 - // (C,t[j-1]) := t[j] + m*q[j] + C - // t[N-1] = C + A - - NO_LOCAL_POINTERS - CMPB ·supportAdx(SB), $1 - JNE noAdx_1 - MOVQ x+8(FP), R8 - - // x[0] -> R10 - // x[1] -> R11 - // x[2] -> R12 - MOVQ 0(R8), R10 - MOVQ 8(R8), R11 - MOVQ 16(R8), R12 - MOVQ y+16(FP), R13 - - // A -> BP - // t[0] -> R14 - // t[1] -> R15 - // t[2] -> CX - // t[3] -> BX - // t[4] -> SI - // t[5] -> DI - // clear the flags - XORQ AX, AX - MOVQ 0(R13), DX - - // (A,t[0]) := x[0]*y[0] + A - MULXQ R10, R14, R15 - - // (A,t[1]) := x[1]*y[0] + A - MULXQ R11, AX, CX - ADOXQ AX, R15 - - // (A,t[2]) := x[2]*y[0] + A - MULXQ R12, AX, BX - ADOXQ AX, CX - - // (A,t[3]) := x[3]*y[0] + A - MULXQ 24(R8), AX, SI - ADOXQ AX, BX - - // (A,t[4]) := x[4]*y[0] + A - MULXQ 32(R8), AX, DI - ADOXQ AX, SI - - // (A,t[5]) := x[5]*y[0] + A - MULXQ 40(R8), AX, BP - ADOXQ AX, DI - - // A += carries from ADCXQ and ADOXQ - MOVQ $0, AX - ADOXQ AX, BP - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ R14, DX - - // clear the flags - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ q<>+0(SB), AX, R9 - ADCXQ R14, AX - MOVQ R9, R14 - - // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R15, R14 - MULXQ q<>+8(SB), AX, R15 - ADOXQ AX, R14 - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ CX, R15 - MULXQ q<>+16(SB), AX, CX - ADOXQ AX, R15 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ BX, CX 
- MULXQ q<>+24(SB), AX, BX - ADOXQ AX, CX - - // (C,t[3]) := t[4] + m*q[4] + C - ADCXQ SI, BX - MULXQ q<>+32(SB), AX, SI - ADOXQ AX, BX - - // (C,t[4]) := t[5] + m*q[5] + C - ADCXQ DI, SI - MULXQ q<>+40(SB), AX, DI - ADOXQ AX, SI - - // t[5] = C + A - MOVQ $0, AX - ADCXQ AX, DI - ADOXQ BP, DI - - // clear the flags - XORQ AX, AX - MOVQ 8(R13), DX - - // (A,t[0]) := t[0] + x[0]*y[1] + A - MULXQ R10, AX, BP - ADOXQ AX, R14 - - // (A,t[1]) := t[1] + x[1]*y[1] + A - ADCXQ BP, R15 - MULXQ R11, AX, BP - ADOXQ AX, R15 - - // (A,t[2]) := t[2] + x[2]*y[1] + A - ADCXQ BP, CX - MULXQ R12, AX, BP - ADOXQ AX, CX - - // (A,t[3]) := t[3] + x[3]*y[1] + A - ADCXQ BP, BX - MULXQ 24(R8), AX, BP - ADOXQ AX, BX - - // (A,t[4]) := t[4] + x[4]*y[1] + A - ADCXQ BP, SI - MULXQ 32(R8), AX, BP - ADOXQ AX, SI - - // (A,t[5]) := t[5] + x[5]*y[1] + A - ADCXQ BP, DI - MULXQ 40(R8), AX, BP - ADOXQ AX, DI - - // A += carries from ADCXQ and ADOXQ - MOVQ $0, AX - ADCXQ AX, BP - ADOXQ AX, BP - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ R14, DX - - // clear the flags - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ q<>+0(SB), AX, R9 - ADCXQ R14, AX - MOVQ R9, R14 - - // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R15, R14 - MULXQ q<>+8(SB), AX, R15 - ADOXQ AX, R14 - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ CX, R15 - MULXQ q<>+16(SB), AX, CX - ADOXQ AX, R15 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ BX, CX - MULXQ q<>+24(SB), AX, BX - ADOXQ AX, CX - - // (C,t[3]) := t[4] + m*q[4] + C - ADCXQ SI, BX - MULXQ q<>+32(SB), AX, SI - ADOXQ AX, BX - - // (C,t[4]) := t[5] + m*q[5] + C - ADCXQ DI, SI - MULXQ q<>+40(SB), AX, DI - ADOXQ AX, SI - - // t[5] = C + A - MOVQ $0, AX - ADCXQ AX, DI - ADOXQ BP, DI - - // clear the flags - XORQ AX, AX - MOVQ 16(R13), DX - - // (A,t[0]) := t[0] + x[0]*y[2] + A - MULXQ R10, AX, BP - ADOXQ AX, R14 - - // (A,t[1]) := t[1] + x[1]*y[2] + A - ADCXQ BP, R15 - MULXQ R11, AX, BP - ADOXQ AX, R15 - - // (A,t[2]) := t[2] + x[2]*y[2] + A - ADCXQ BP, CX - MULXQ R12, AX, BP - ADOXQ AX, CX - - // (A,t[3]) := t[3] + x[3]*y[2] + A - ADCXQ BP, BX - MULXQ 24(R8), AX, BP - ADOXQ AX, BX - - // (A,t[4]) := t[4] + x[4]*y[2] + A - ADCXQ BP, SI - MULXQ 32(R8), AX, BP - ADOXQ AX, SI - - // (A,t[5]) := t[5] + x[5]*y[2] + A - ADCXQ BP, DI - MULXQ 40(R8), AX, BP - ADOXQ AX, DI - - // A += carries from ADCXQ and ADOXQ - MOVQ $0, AX - ADCXQ AX, BP - ADOXQ AX, BP - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ R14, DX - - // clear the flags - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ q<>+0(SB), AX, R9 - ADCXQ R14, AX - MOVQ R9, R14 - - // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R15, R14 - MULXQ q<>+8(SB), AX, R15 - ADOXQ AX, R14 - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ CX, R15 - MULXQ q<>+16(SB), AX, CX - ADOXQ AX, R15 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ BX, CX - MULXQ q<>+24(SB), AX, BX - ADOXQ AX, CX - - // (C,t[3]) := t[4] + m*q[4] + C - ADCXQ SI, BX - MULXQ q<>+32(SB), AX, SI - ADOXQ AX, BX - - // (C,t[4]) := t[5] + m*q[5] + C - ADCXQ DI, SI - MULXQ q<>+40(SB), AX, DI - ADOXQ AX, SI - - // t[5] = C + A - MOVQ $0, AX - ADCXQ AX, DI - ADOXQ BP, DI - - // clear the flags - XORQ AX, AX - MOVQ 24(R13), DX - - // (A,t[0]) := t[0] + x[0]*y[3] + A - MULXQ R10, AX, BP - ADOXQ AX, R14 - - // (A,t[1]) := t[1] + x[1]*y[3] + A - ADCXQ BP, R15 - MULXQ R11, AX, BP - ADOXQ AX, R15 - - // (A,t[2]) := t[2] + x[2]*y[3] + A - ADCXQ BP, CX - MULXQ R12, AX, BP - ADOXQ AX, CX - - // (A,t[3]) := t[3] + x[3]*y[3] + A - ADCXQ BP, BX - MULXQ 24(R8), AX, BP - ADOXQ AX, BX - - // (A,t[4]) := t[4] + x[4]*y[3] + A 
- ADCXQ BP, SI - MULXQ 32(R8), AX, BP - ADOXQ AX, SI - - // (A,t[5]) := t[5] + x[5]*y[3] + A - ADCXQ BP, DI - MULXQ 40(R8), AX, BP - ADOXQ AX, DI - - // A += carries from ADCXQ and ADOXQ - MOVQ $0, AX - ADCXQ AX, BP - ADOXQ AX, BP - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ R14, DX - - // clear the flags - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ q<>+0(SB), AX, R9 - ADCXQ R14, AX - MOVQ R9, R14 - - // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R15, R14 - MULXQ q<>+8(SB), AX, R15 - ADOXQ AX, R14 - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ CX, R15 - MULXQ q<>+16(SB), AX, CX - ADOXQ AX, R15 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ BX, CX - MULXQ q<>+24(SB), AX, BX - ADOXQ AX, CX - - // (C,t[3]) := t[4] + m*q[4] + C - ADCXQ SI, BX - MULXQ q<>+32(SB), AX, SI - ADOXQ AX, BX - - // (C,t[4]) := t[5] + m*q[5] + C - ADCXQ DI, SI - MULXQ q<>+40(SB), AX, DI - ADOXQ AX, SI - - // t[5] = C + A - MOVQ $0, AX - ADCXQ AX, DI - ADOXQ BP, DI - - // clear the flags - XORQ AX, AX - MOVQ 32(R13), DX - - // (A,t[0]) := t[0] + x[0]*y[4] + A - MULXQ R10, AX, BP - ADOXQ AX, R14 - - // (A,t[1]) := t[1] + x[1]*y[4] + A - ADCXQ BP, R15 - MULXQ R11, AX, BP - ADOXQ AX, R15 - - // (A,t[2]) := t[2] + x[2]*y[4] + A - ADCXQ BP, CX - MULXQ R12, AX, BP - ADOXQ AX, CX - - // (A,t[3]) := t[3] + x[3]*y[4] + A - ADCXQ BP, BX - MULXQ 24(R8), AX, BP - ADOXQ AX, BX - - // (A,t[4]) := t[4] + x[4]*y[4] + A - ADCXQ BP, SI - MULXQ 32(R8), AX, BP - ADOXQ AX, SI - - // (A,t[5]) := t[5] + x[5]*y[4] + A - ADCXQ BP, DI - MULXQ 40(R8), AX, BP - ADOXQ AX, DI - - // A += carries from ADCXQ and ADOXQ - MOVQ $0, AX - ADCXQ AX, BP - ADOXQ AX, BP - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ R14, DX - - // clear the flags - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ q<>+0(SB), AX, R9 - ADCXQ R14, AX - MOVQ R9, R14 - - // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R15, R14 - MULXQ q<>+8(SB), AX, R15 - ADOXQ AX, R14 - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ CX, R15 - MULXQ q<>+16(SB), AX, CX - ADOXQ AX, R15 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ BX, CX - MULXQ q<>+24(SB), AX, BX - ADOXQ AX, CX - - // (C,t[3]) := t[4] + m*q[4] + C - ADCXQ SI, BX - MULXQ q<>+32(SB), AX, SI - ADOXQ AX, BX - - // (C,t[4]) := t[5] + m*q[5] + C - ADCXQ DI, SI - MULXQ q<>+40(SB), AX, DI - ADOXQ AX, SI - - // t[5] = C + A - MOVQ $0, AX - ADCXQ AX, DI - ADOXQ BP, DI - - // clear the flags - XORQ AX, AX - MOVQ 40(R13), DX - - // (A,t[0]) := t[0] + x[0]*y[5] + A - MULXQ R10, AX, BP - ADOXQ AX, R14 - - // (A,t[1]) := t[1] + x[1]*y[5] + A - ADCXQ BP, R15 - MULXQ R11, AX, BP - ADOXQ AX, R15 - - // (A,t[2]) := t[2] + x[2]*y[5] + A - ADCXQ BP, CX - MULXQ R12, AX, BP - ADOXQ AX, CX - - // (A,t[3]) := t[3] + x[3]*y[5] + A - ADCXQ BP, BX - MULXQ 24(R8), AX, BP - ADOXQ AX, BX - - // (A,t[4]) := t[4] + x[4]*y[5] + A - ADCXQ BP, SI - MULXQ 32(R8), AX, BP - ADOXQ AX, SI - - // (A,t[5]) := t[5] + x[5]*y[5] + A - ADCXQ BP, DI - MULXQ 40(R8), AX, BP - ADOXQ AX, DI - - // A += carries from ADCXQ and ADOXQ - MOVQ $0, AX - ADCXQ AX, BP - ADOXQ AX, BP - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ R14, DX - - // clear the flags - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ q<>+0(SB), AX, R9 - ADCXQ R14, AX - MOVQ R9, R14 - - // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R15, R14 - MULXQ q<>+8(SB), AX, R15 - ADOXQ AX, R14 - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ CX, R15 - MULXQ q<>+16(SB), AX, CX - ADOXQ AX, R15 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ BX, CX - MULXQ q<>+24(SB), AX, BX - ADOXQ AX, CX - - // (C,t[3]) := t[4] + m*q[4] 
+ C - ADCXQ SI, BX - MULXQ q<>+32(SB), AX, SI - ADOXQ AX, BX - - // (C,t[4]) := t[5] + m*q[5] + C - ADCXQ DI, SI - MULXQ q<>+40(SB), AX, DI - ADOXQ AX, SI - - // t[5] = C + A - MOVQ $0, AX - ADCXQ AX, DI - ADOXQ BP, DI - - // reduce element(R14,R15,CX,BX,SI,DI) using temp registers (R9,R8,R13,R10,R11,R12) - REDUCE(R14,R15,CX,BX,SI,DI,R9,R8,R13,R10,R11,R12) - - MOVQ res+0(FP), AX - MOVQ R14, 0(AX) - MOVQ R15, 8(AX) - MOVQ CX, 16(AX) - MOVQ BX, 24(AX) - MOVQ SI, 32(AX) - MOVQ DI, 40(AX) - RET - -noAdx_1: - MOVQ res+0(FP), AX - MOVQ AX, (SP) - MOVQ x+8(FP), AX - MOVQ AX, 8(SP) - MOVQ y+16(FP), AX - MOVQ AX, 16(SP) - CALL ·_mulGeneric(SB) - RET - -TEXT ·fromMont(SB), $8-8 - NO_LOCAL_POINTERS - - // the algorithm is described here - // https://hackmd.io/@gnark/modular_multiplication - // when y = 1 we have: - // for i=0 to N-1 - // t[i] = x[i] - // for i=0 to N-1 - // m := t[0]*q'[0] mod W - // C,_ := t[0] + m*q[0] - // for j=1 to N-1 - // (C,t[j-1]) := t[j] + m*q[j] + C - // t[N-1] = C - CMPB ·supportAdx(SB), $1 - JNE noAdx_2 - MOVQ res+0(FP), DX - MOVQ 0(DX), R14 - MOVQ 8(DX), R15 - MOVQ 16(DX), CX - MOVQ 24(DX), BX - MOVQ 32(DX), SI - MOVQ 40(DX), DI - XORQ DX, DX - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ R14, DX - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ q<>+0(SB), AX, BP - ADCXQ R14, AX - MOVQ BP, R14 - - // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R15, R14 - MULXQ q<>+8(SB), AX, R15 - ADOXQ AX, R14 - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ CX, R15 - MULXQ q<>+16(SB), AX, CX - ADOXQ AX, R15 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ BX, CX - MULXQ q<>+24(SB), AX, BX - ADOXQ AX, CX - - // (C,t[3]) := t[4] + m*q[4] + C - ADCXQ SI, BX - MULXQ q<>+32(SB), AX, SI - ADOXQ AX, BX - - // (C,t[4]) := t[5] + m*q[5] + C - ADCXQ DI, SI - MULXQ q<>+40(SB), AX, DI - ADOXQ AX, SI - MOVQ $0, AX - ADCXQ AX, DI - ADOXQ AX, DI - XORQ DX, DX - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ R14, DX - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ q<>+0(SB), AX, BP - ADCXQ R14, AX - MOVQ BP, R14 - - // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R15, R14 - MULXQ q<>+8(SB), AX, R15 - ADOXQ AX, R14 - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ CX, R15 - MULXQ q<>+16(SB), AX, CX - ADOXQ AX, R15 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ BX, CX - MULXQ q<>+24(SB), AX, BX - ADOXQ AX, CX - - // (C,t[3]) := t[4] + m*q[4] + C - ADCXQ SI, BX - MULXQ q<>+32(SB), AX, SI - ADOXQ AX, BX - - // (C,t[4]) := t[5] + m*q[5] + C - ADCXQ DI, SI - MULXQ q<>+40(SB), AX, DI - ADOXQ AX, SI - MOVQ $0, AX - ADCXQ AX, DI - ADOXQ AX, DI - XORQ DX, DX - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ R14, DX - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ q<>+0(SB), AX, BP - ADCXQ R14, AX - MOVQ BP, R14 - - // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R15, R14 - MULXQ q<>+8(SB), AX, R15 - ADOXQ AX, R14 - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ CX, R15 - MULXQ q<>+16(SB), AX, CX - ADOXQ AX, R15 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ BX, CX - MULXQ q<>+24(SB), AX, BX - ADOXQ AX, CX - - // (C,t[3]) := t[4] + m*q[4] + C - ADCXQ SI, BX - MULXQ q<>+32(SB), AX, SI - ADOXQ AX, BX - - // (C,t[4]) := t[5] + m*q[5] + C - ADCXQ DI, SI - MULXQ q<>+40(SB), AX, DI - ADOXQ AX, SI - MOVQ $0, AX - ADCXQ AX, DI - ADOXQ AX, DI - XORQ DX, DX - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ R14, DX - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ q<>+0(SB), AX, BP - ADCXQ R14, AX - MOVQ BP, R14 - - // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R15, R14 - MULXQ q<>+8(SB), AX, R15 - ADOXQ AX, R14 - - // 
(C,t[1]) := t[2] + m*q[2] + C - ADCXQ CX, R15 - MULXQ q<>+16(SB), AX, CX - ADOXQ AX, R15 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ BX, CX - MULXQ q<>+24(SB), AX, BX - ADOXQ AX, CX - - // (C,t[3]) := t[4] + m*q[4] + C - ADCXQ SI, BX - MULXQ q<>+32(SB), AX, SI - ADOXQ AX, BX - - // (C,t[4]) := t[5] + m*q[5] + C - ADCXQ DI, SI - MULXQ q<>+40(SB), AX, DI - ADOXQ AX, SI - MOVQ $0, AX - ADCXQ AX, DI - ADOXQ AX, DI - XORQ DX, DX - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ R14, DX - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ q<>+0(SB), AX, BP - ADCXQ R14, AX - MOVQ BP, R14 - - // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R15, R14 - MULXQ q<>+8(SB), AX, R15 - ADOXQ AX, R14 - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ CX, R15 - MULXQ q<>+16(SB), AX, CX - ADOXQ AX, R15 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ BX, CX - MULXQ q<>+24(SB), AX, BX - ADOXQ AX, CX - - // (C,t[3]) := t[4] + m*q[4] + C - ADCXQ SI, BX - MULXQ q<>+32(SB), AX, SI - ADOXQ AX, BX - - // (C,t[4]) := t[5] + m*q[5] + C - ADCXQ DI, SI - MULXQ q<>+40(SB), AX, DI - ADOXQ AX, SI - MOVQ $0, AX - ADCXQ AX, DI - ADOXQ AX, DI - XORQ DX, DX - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ R14, DX - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ q<>+0(SB), AX, BP - ADCXQ R14, AX - MOVQ BP, R14 - - // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R15, R14 - MULXQ q<>+8(SB), AX, R15 - ADOXQ AX, R14 - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ CX, R15 - MULXQ q<>+16(SB), AX, CX - ADOXQ AX, R15 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ BX, CX - MULXQ q<>+24(SB), AX, BX - ADOXQ AX, CX - - // (C,t[3]) := t[4] + m*q[4] + C - ADCXQ SI, BX - MULXQ q<>+32(SB), AX, SI - ADOXQ AX, BX - - // (C,t[4]) := t[5] + m*q[5] + C - ADCXQ DI, SI - MULXQ q<>+40(SB), AX, DI - ADOXQ AX, SI - MOVQ $0, AX - ADCXQ AX, DI - ADOXQ AX, DI - - // reduce element(R14,R15,CX,BX,SI,DI) using temp registers (R8,R9,R10,R11,R12,R13) - REDUCE(R14,R15,CX,BX,SI,DI,R8,R9,R10,R11,R12,R13) - - MOVQ res+0(FP), AX - MOVQ R14, 0(AX) - MOVQ R15, 8(AX) - MOVQ CX, 16(AX) - MOVQ BX, 24(AX) - MOVQ SI, 32(AX) - MOVQ DI, 40(AX) - RET - -noAdx_2: - MOVQ res+0(FP), AX - MOVQ AX, (SP) - CALL ·_fromMontGeneric(SB) - RET diff --git a/ecc/bw6-761/fr/element_ops_amd64.go b/ecc/bw6-761/fr/element_ops_amd64.go index e40a9caed..83d40c28c 100644 --- a/ecc/bw6-761/fr/element_ops_amd64.go +++ b/ecc/bw6-761/fr/element_ops_amd64.go @@ -50,48 +50,8 @@ func Butterfly(a, b *Element) // x and y must be less than q func (z *Element) Mul(x, y *Element) *Element { - // Implements CIOS multiplication -- section 2.3.2 of Tolga Acar's thesis - // https://www.microsoft.com/en-us/research/wp-content/uploads/1998/06/97Acar.pdf - // - // The algorithm: - // - // for i=0 to N-1 - // C := 0 - // for j=0 to N-1 - // (C,t[j]) := t[j] + x[j]*y[i] + C - // (t[N+1],t[N]) := t[N] + C - // - // C := 0 - // m := t[0]*q'[0] mod D - // (C,_) := t[0] + m*q[0] - // for j=1 to N-1 - // (C,t[j-1]) := t[j] + m*q[j] + C - // - // (C,t[N-1]) := t[N] + C - // t[N] := t[N+1] + C - // - // → N is the number of machine words needed to store the modulus q - // → D is the word size. For example, on a 64-bit architecture D is 2 64 - // → x[i], y[i], q[i] is the ith word of the numbers x,y,q - // → q'[0] is the lowest word of the number -q⁻¹ mod r. This quantity is pre-computed, as it does not depend on the inputs. - // → t is a temporary array of size N+2 - // → C, S are machine words. 
A pair (C,S) refers to (hi-bits, lo-bits) of a two-word number - // - // As described here https://hackmd.io/@gnark/modular_multiplication we can get rid of one carry chain and simplify: - // (also described in https://eprint.iacr.org/2022/1400.pdf annex) - // - // for i=0 to N-1 - // (A,t[0]) := t[0] + x[0]*y[i] - // m := t[0]*q'[0] mod W - // C,_ := t[0] + m*q[0] - // for j=1 to N-1 - // (A,t[j]) := t[j] + x[j]*y[i] + A - // (C,t[j-1]) := t[j] + m*q[j] + C - // - // t[N-1] = C + A - // - // This optimization saves 5N + 2 additions in the algorithm, and can be used whenever the highest bit - // of the modulus is zero (and not all of the remaining bits are set). + // Algorithm 2 of "Faster Montgomery Multiplication and Multi-Scalar-Multiplication for SNARKS" + // by Y. El Housni and G. Botrel https://doi.org/10.46586/tches.v2023.i3.504-521 mul(z, x, y) return z diff --git a/ecc/bw6-761/fr/element_ops_amd64.s b/ecc/bw6-761/fr/element_ops_amd64.s index 7242622a4..cd10b7d6d 100644 --- a/ecc/bw6-761/fr/element_ops_amd64.s +++ b/ecc/bw6-761/fr/element_ops_amd64.s @@ -1,21 +1,13 @@ // +build !purego -// Copyright 2020 ConsenSys Software Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. +#define q0 $0x8508c00000000001 +#define q1 $0x170b5d4430000000 +#define q2 $0x1ef3622fba094800 +#define q3 $0x1a22d9f300f5138f +#define q4 $0xc63b05c06ca1493b +#define q5 $0x01ae3a4617c510ea -#include "textflag.h" -#include "funcdata.h" +#include "../../../field/asm/element_6w_amd64.h" // modulus q DATA q<>+0(SB)/8, $0x8508c00000000001 @@ -30,277 +22,3 @@ GLOBL q<>(SB), (RODATA+NOPTR), $48 DATA qInv0<>(SB)/8, $0x8508bfffffffffff GLOBL qInv0<>(SB), (RODATA+NOPTR), $8 -#define REDUCE(ra0, ra1, ra2, ra3, ra4, ra5, rb0, rb1, rb2, rb3, rb4, rb5) \ - MOVQ ra0, rb0; \ - SUBQ q<>(SB), ra0; \ - MOVQ ra1, rb1; \ - SBBQ q<>+8(SB), ra1; \ - MOVQ ra2, rb2; \ - SBBQ q<>+16(SB), ra2; \ - MOVQ ra3, rb3; \ - SBBQ q<>+24(SB), ra3; \ - MOVQ ra4, rb4; \ - SBBQ q<>+32(SB), ra4; \ - MOVQ ra5, rb5; \ - SBBQ q<>+40(SB), ra5; \ - CMOVQCS rb0, ra0; \ - CMOVQCS rb1, ra1; \ - CMOVQCS rb2, ra2; \ - CMOVQCS rb3, ra3; \ - CMOVQCS rb4, ra4; \ - CMOVQCS rb5, ra5; \ - -TEXT ·reduce(SB), NOSPLIT, $0-8 - MOVQ res+0(FP), AX - MOVQ 0(AX), DX - MOVQ 8(AX), CX - MOVQ 16(AX), BX - MOVQ 24(AX), SI - MOVQ 32(AX), DI - MOVQ 40(AX), R8 - - // reduce element(DX,CX,BX,SI,DI,R8) using temp registers (R9,R10,R11,R12,R13,R14) - REDUCE(DX,CX,BX,SI,DI,R8,R9,R10,R11,R12,R13,R14) - - MOVQ DX, 0(AX) - MOVQ CX, 8(AX) - MOVQ BX, 16(AX) - MOVQ SI, 24(AX) - MOVQ DI, 32(AX) - MOVQ R8, 40(AX) - RET - -// MulBy3(x *Element) -TEXT ·MulBy3(SB), NOSPLIT, $0-8 - MOVQ x+0(FP), AX - MOVQ 0(AX), DX - MOVQ 8(AX), CX - MOVQ 16(AX), BX - MOVQ 24(AX), SI - MOVQ 32(AX), DI - MOVQ 40(AX), R8 - ADDQ DX, DX - ADCQ CX, CX - ADCQ BX, BX - ADCQ SI, SI - ADCQ DI, DI - ADCQ R8, R8 - - // reduce element(DX,CX,BX,SI,DI,R8) using temp registers (R9,R10,R11,R12,R13,R14) - REDUCE(DX,CX,BX,SI,DI,R8,R9,R10,R11,R12,R13,R14) - - ADDQ 0(AX), DX - ADCQ 8(AX), CX - ADCQ 16(AX), BX - ADCQ 
24(AX), SI - ADCQ 32(AX), DI - ADCQ 40(AX), R8 - - // reduce element(DX,CX,BX,SI,DI,R8) using temp registers (R15,R9,R10,R11,R12,R13) - REDUCE(DX,CX,BX,SI,DI,R8,R15,R9,R10,R11,R12,R13) - - MOVQ DX, 0(AX) - MOVQ CX, 8(AX) - MOVQ BX, 16(AX) - MOVQ SI, 24(AX) - MOVQ DI, 32(AX) - MOVQ R8, 40(AX) - RET - -// MulBy5(x *Element) -TEXT ·MulBy5(SB), NOSPLIT, $0-8 - MOVQ x+0(FP), AX - MOVQ 0(AX), DX - MOVQ 8(AX), CX - MOVQ 16(AX), BX - MOVQ 24(AX), SI - MOVQ 32(AX), DI - MOVQ 40(AX), R8 - ADDQ DX, DX - ADCQ CX, CX - ADCQ BX, BX - ADCQ SI, SI - ADCQ DI, DI - ADCQ R8, R8 - - // reduce element(DX,CX,BX,SI,DI,R8) using temp registers (R9,R10,R11,R12,R13,R14) - REDUCE(DX,CX,BX,SI,DI,R8,R9,R10,R11,R12,R13,R14) - - ADDQ DX, DX - ADCQ CX, CX - ADCQ BX, BX - ADCQ SI, SI - ADCQ DI, DI - ADCQ R8, R8 - - // reduce element(DX,CX,BX,SI,DI,R8) using temp registers (R15,R9,R10,R11,R12,R13) - REDUCE(DX,CX,BX,SI,DI,R8,R15,R9,R10,R11,R12,R13) - - ADDQ 0(AX), DX - ADCQ 8(AX), CX - ADCQ 16(AX), BX - ADCQ 24(AX), SI - ADCQ 32(AX), DI - ADCQ 40(AX), R8 - - // reduce element(DX,CX,BX,SI,DI,R8) using temp registers (R14,R15,R9,R10,R11,R12) - REDUCE(DX,CX,BX,SI,DI,R8,R14,R15,R9,R10,R11,R12) - - MOVQ DX, 0(AX) - MOVQ CX, 8(AX) - MOVQ BX, 16(AX) - MOVQ SI, 24(AX) - MOVQ DI, 32(AX) - MOVQ R8, 40(AX) - RET - -// MulBy13(x *Element) -TEXT ·MulBy13(SB), $40-8 - MOVQ x+0(FP), AX - MOVQ 0(AX), DX - MOVQ 8(AX), CX - MOVQ 16(AX), BX - MOVQ 24(AX), SI - MOVQ 32(AX), DI - MOVQ 40(AX), R8 - ADDQ DX, DX - ADCQ CX, CX - ADCQ BX, BX - ADCQ SI, SI - ADCQ DI, DI - ADCQ R8, R8 - - // reduce element(DX,CX,BX,SI,DI,R8) using temp registers (R9,R10,R11,R12,R13,R14) - REDUCE(DX,CX,BX,SI,DI,R8,R9,R10,R11,R12,R13,R14) - - ADDQ DX, DX - ADCQ CX, CX - ADCQ BX, BX - ADCQ SI, SI - ADCQ DI, DI - ADCQ R8, R8 - - // reduce element(DX,CX,BX,SI,DI,R8) using temp registers (R15,s0-8(SP),s1-16(SP),s2-24(SP),s3-32(SP),s4-40(SP)) - REDUCE(DX,CX,BX,SI,DI,R8,R15,s0-8(SP),s1-16(SP),s2-24(SP),s3-32(SP),s4-40(SP)) - - MOVQ DX, R15 - MOVQ CX, s0-8(SP) - MOVQ BX, s1-16(SP) - MOVQ SI, s2-24(SP) - MOVQ DI, s3-32(SP) - MOVQ R8, s4-40(SP) - ADDQ DX, DX - ADCQ CX, CX - ADCQ BX, BX - ADCQ SI, SI - ADCQ DI, DI - ADCQ R8, R8 - - // reduce element(DX,CX,BX,SI,DI,R8) using temp registers (R9,R10,R11,R12,R13,R14) - REDUCE(DX,CX,BX,SI,DI,R8,R9,R10,R11,R12,R13,R14) - - ADDQ R15, DX - ADCQ s0-8(SP), CX - ADCQ s1-16(SP), BX - ADCQ s2-24(SP), SI - ADCQ s3-32(SP), DI - ADCQ s4-40(SP), R8 - - // reduce element(DX,CX,BX,SI,DI,R8) using temp registers (R9,R10,R11,R12,R13,R14) - REDUCE(DX,CX,BX,SI,DI,R8,R9,R10,R11,R12,R13,R14) - - ADDQ 0(AX), DX - ADCQ 8(AX), CX - ADCQ 16(AX), BX - ADCQ 24(AX), SI - ADCQ 32(AX), DI - ADCQ 40(AX), R8 - - // reduce element(DX,CX,BX,SI,DI,R8) using temp registers (R9,R10,R11,R12,R13,R14) - REDUCE(DX,CX,BX,SI,DI,R8,R9,R10,R11,R12,R13,R14) - - MOVQ DX, 0(AX) - MOVQ CX, 8(AX) - MOVQ BX, 16(AX) - MOVQ SI, 24(AX) - MOVQ DI, 32(AX) - MOVQ R8, 40(AX) - RET - -// Butterfly(a, b *Element) sets a = a + b; b = a - b -TEXT ·Butterfly(SB), $48-16 - MOVQ a+0(FP), AX - MOVQ 0(AX), CX - MOVQ 8(AX), BX - MOVQ 16(AX), SI - MOVQ 24(AX), DI - MOVQ 32(AX), R8 - MOVQ 40(AX), R9 - MOVQ CX, R10 - MOVQ BX, R11 - MOVQ SI, R12 - MOVQ DI, R13 - MOVQ R8, R14 - MOVQ R9, R15 - XORQ AX, AX - MOVQ b+8(FP), DX - ADDQ 0(DX), CX - ADCQ 8(DX), BX - ADCQ 16(DX), SI - ADCQ 24(DX), DI - ADCQ 32(DX), R8 - ADCQ 40(DX), R9 - SUBQ 0(DX), R10 - SBBQ 8(DX), R11 - SBBQ 16(DX), R12 - SBBQ 24(DX), R13 - SBBQ 32(DX), R14 - SBBQ 40(DX), R15 - MOVQ CX, s0-8(SP) - MOVQ BX, s1-16(SP) - MOVQ SI, s2-24(SP) - MOVQ 
DI, s3-32(SP) - MOVQ R8, s4-40(SP) - MOVQ R9, s5-48(SP) - MOVQ $0x8508c00000000001, CX - MOVQ $0x170b5d4430000000, BX - MOVQ $0x1ef3622fba094800, SI - MOVQ $0x1a22d9f300f5138f, DI - MOVQ $0xc63b05c06ca1493b, R8 - MOVQ $0x01ae3a4617c510ea, R9 - CMOVQCC AX, CX - CMOVQCC AX, BX - CMOVQCC AX, SI - CMOVQCC AX, DI - CMOVQCC AX, R8 - CMOVQCC AX, R9 - ADDQ CX, R10 - ADCQ BX, R11 - ADCQ SI, R12 - ADCQ DI, R13 - ADCQ R8, R14 - ADCQ R9, R15 - MOVQ s0-8(SP), CX - MOVQ s1-16(SP), BX - MOVQ s2-24(SP), SI - MOVQ s3-32(SP), DI - MOVQ s4-40(SP), R8 - MOVQ s5-48(SP), R9 - MOVQ R10, 0(DX) - MOVQ R11, 8(DX) - MOVQ R12, 16(DX) - MOVQ R13, 24(DX) - MOVQ R14, 32(DX) - MOVQ R15, 40(DX) - - // reduce element(CX,BX,SI,DI,R8,R9) using temp registers (R10,R11,R12,R13,R14,R15) - REDUCE(CX,BX,SI,DI,R8,R9,R10,R11,R12,R13,R14,R15) - - MOVQ a+0(FP), AX - MOVQ CX, 0(AX) - MOVQ BX, 8(AX) - MOVQ SI, 16(AX) - MOVQ DI, 24(AX) - MOVQ R8, 32(AX) - MOVQ R9, 40(AX) - RET diff --git a/ecc/bw6-761/fr/element_ops_purego.go b/ecc/bw6-761/fr/element_ops_purego.go index bd2d33293..bdf76428d 100644 --- a/ecc/bw6-761/fr/element_ops_purego.go +++ b/ecc/bw6-761/fr/element_ops_purego.go @@ -67,48 +67,8 @@ func reduce(z *Element) { // x and y must be less than q func (z *Element) Mul(x, y *Element) *Element { - // Implements CIOS multiplication -- section 2.3.2 of Tolga Acar's thesis - // https://www.microsoft.com/en-us/research/wp-content/uploads/1998/06/97Acar.pdf - // - // The algorithm: - // - // for i=0 to N-1 - // C := 0 - // for j=0 to N-1 - // (C,t[j]) := t[j] + x[j]*y[i] + C - // (t[N+1],t[N]) := t[N] + C - // - // C := 0 - // m := t[0]*q'[0] mod D - // (C,_) := t[0] + m*q[0] - // for j=1 to N-1 - // (C,t[j-1]) := t[j] + m*q[j] + C - // - // (C,t[N-1]) := t[N] + C - // t[N] := t[N+1] + C - // - // → N is the number of machine words needed to store the modulus q - // → D is the word size. For example, on a 64-bit architecture D is 2 64 - // → x[i], y[i], q[i] is the ith word of the numbers x,y,q - // → q'[0] is the lowest word of the number -q⁻¹ mod r. This quantity is pre-computed, as it does not depend on the inputs. - // → t is a temporary array of size N+2 - // → C, S are machine words. A pair (C,S) refers to (hi-bits, lo-bits) of a two-word number - // - // As described here https://hackmd.io/@gnark/modular_multiplication we can get rid of one carry chain and simplify: - // (also described in https://eprint.iacr.org/2022/1400.pdf annex) - // - // for i=0 to N-1 - // (A,t[0]) := t[0] + x[0]*y[i] - // m := t[0]*q'[0] mod W - // C,_ := t[0] + m*q[0] - // for j=1 to N-1 - // (A,t[j]) := t[j] + x[j]*y[i] + A - // (C,t[j-1]) := t[j] + m*q[j] + C - // - // t[N-1] = C + A - // - // This optimization saves 5N + 2 additions in the algorithm, and can be used whenever the highest bit - // of the modulus is zero (and not all of the remaining bits are set). + // Algorithm 2 of "Faster Montgomery Multiplication and Multi-Scalar-Multiplication for SNARKS" + // by Y. El Housni and G. 
Botrel https://doi.org/10.46586/tches.v2023.i3.504-521 var t0, t1, t2, t3, t4, t5 uint64 var u0, u1, u2, u3, u4, u5 uint64 diff --git a/ecc/bw6-761/fr/element_test.go b/ecc/bw6-761/fr/element_test.go index 35f8f44da..198e3afae 100644 --- a/ecc/bw6-761/fr/element_test.go +++ b/ecc/bw6-761/fr/element_test.go @@ -641,7 +641,6 @@ func TestElementBitLen(t *testing.T) { )) properties.TestingRun(t, gopter.ConsoleReporter(false)) - } func TestElementButterflies(t *testing.T) { @@ -723,11 +722,17 @@ func TestElementVecOps(t *testing.T) { // set m to max values element // it's not really q-1 (since we have montgomery representation) // but it's the "largest" legal value - qMinus1 := new(big.Int).Sub(Modulus(), big.NewInt(1)) - - var eQMinus1 Element - for i, v := range qMinus1.Bits() { - eQMinus1[i] = uint64(v) + eQMinus1 := qElement + if eQMinus1[0] != 0 { + eQMinus1[0]-- + } else { + eQMinus1[0] = ^uint64(0) + for i := 1; i < len(eQMinus1); i++ { + if eQMinus1[i] != 0 { + eQMinus1[i]-- + break + } + } } for i := 0; i < N; i++ { diff --git a/ecc/secp256k1/fp/element.go b/ecc/secp256k1/fp/element.go index 0a242dd37..d030964bb 100644 --- a/ecc/secp256k1/fp/element.go +++ b/ecc/secp256k1/fp/element.go @@ -505,32 +505,8 @@ func (z *Element) Select(c int, x0 *Element, x1 *Element) *Element { // and is used for testing purposes. func _mulGeneric(z, x, y *Element) { - // Implements CIOS multiplication -- section 2.3.2 of Tolga Acar's thesis - // https://www.microsoft.com/en-us/research/wp-content/uploads/1998/06/97Acar.pdf - // - // The algorithm: - // - // for i=0 to N-1 - // C := 0 - // for j=0 to N-1 - // (C,t[j]) := t[j] + x[j]*y[i] + C - // (t[N+1],t[N]) := t[N] + C - // - // C := 0 - // m := t[0]*q'[0] mod D - // (C,_) := t[0] + m*q[0] - // for j=1 to N-1 - // (C,t[j-1]) := t[j] + m*q[j] + C - // - // (C,t[N-1]) := t[N] + C - // t[N] := t[N+1] + C - // - // → N is the number of machine words needed to store the modulus q - // → D is the word size. For example, on a 64-bit architecture D is 2 64 - // → x[i], y[i], q[i] is the ith word of the numbers x,y,q - // → q'[0] is the lowest word of the number -q⁻¹ mod r. This quantity is pre-computed, as it does not depend on the inputs. - // → t is a temporary array of size N+2 - // → C, S are machine words. A pair (C,S) refers to (hi-bits, lo-bits) of a two-word number + // Algorithm 2 of "Faster Montgomery Multiplication and Multi-Scalar-Multiplication for SNARKS" + // by Y. El Housni and G. Botrel https://doi.org/10.46586/tches.v2023.i3.504-521 var t [5]uint64 var D uint64 diff --git a/ecc/secp256k1/fp/element_ops_purego.go b/ecc/secp256k1/fp/element_ops_purego.go index 5e8497f5b..f53ffa325 100644 --- a/ecc/secp256k1/fp/element_ops_purego.go +++ b/ecc/secp256k1/fp/element_ops_purego.go @@ -57,59 +57,11 @@ func reduce(z *Element) { _reduceGeneric(z) } -// Add adds two vectors element-wise and stores the result in self. -// It panics if the vectors don't have the same length. -func (vector *Vector) Add(a, b Vector) { - addVecGeneric(*vector, a, b) -} - -// Sub subtracts two vectors element-wise and stores the result in self. -// It panics if the vectors don't have the same length. -func (vector *Vector) Sub(a, b Vector) { - subVecGeneric(*vector, a, b) -} - -// ScalarMul multiplies a vector by a scalar element-wise and stores the result in self. -// It panics if the vectors don't have the same length. -func (vector *Vector) ScalarMul(a Vector, b *Element) { - scalarMulVecGeneric(*vector, a, b) -} - -// Sum computes the sum of all elements in the vector. 
-func (vector *Vector) Sum() (res Element) { - sumVecGeneric(&res, *vector) - return -} - // Mul z = x * y (mod q) func (z *Element) Mul(x, y *Element) *Element { - // Implements CIOS multiplication -- section 2.3.2 of Tolga Acar's thesis - // https://www.microsoft.com/en-us/research/wp-content/uploads/1998/06/97Acar.pdf - // - // The algorithm: - // - // for i=0 to N-1 - // C := 0 - // for j=0 to N-1 - // (C,t[j]) := t[j] + x[j]*y[i] + C - // (t[N+1],t[N]) := t[N] + C - // - // C := 0 - // m := t[0]*q'[0] mod D - // (C,_) := t[0] + m*q[0] - // for j=1 to N-1 - // (C,t[j-1]) := t[j] + m*q[j] + C - // - // (C,t[N-1]) := t[N] + C - // t[N] := t[N+1] + C - // - // → N is the number of machine words needed to store the modulus q - // → D is the word size. For example, on a 64-bit architecture D is 2 64 - // → x[i], y[i], q[i] is the ith word of the numbers x,y,q - // → q'[0] is the lowest word of the number -q⁻¹ mod r. This quantity is pre-computed, as it does not depend on the inputs. - // → t is a temporary array of size N+2 - // → C, S are machine words. A pair (C,S) refers to (hi-bits, lo-bits) of a two-word number + // Algorithm 2 of "Faster Montgomery Multiplication and Multi-Scalar-Multiplication for SNARKS" + // by Y. El Housni and G. Botrel https://doi.org/10.46586/tches.v2023.i3.504-521 var t [5]uint64 var D uint64 diff --git a/ecc/secp256k1/fp/element_test.go b/ecc/secp256k1/fp/element_test.go index a8672f7e1..e4896c975 100644 --- a/ecc/secp256k1/fp/element_test.go +++ b/ecc/secp256k1/fp/element_test.go @@ -635,7 +635,6 @@ func TestElementBitLen(t *testing.T) { )) properties.TestingRun(t, gopter.ConsoleReporter(false)) - } func TestElementButterflies(t *testing.T) { @@ -717,11 +716,17 @@ func TestElementVecOps(t *testing.T) { // set m to max values element // it's not really q-1 (since we have montgomery representation) // but it's the "largest" legal value - qMinus1 := new(big.Int).Sub(Modulus(), big.NewInt(1)) - - var eQMinus1 Element - for i, v := range qMinus1.Bits() { - eQMinus1[i] = uint64(v) + eQMinus1 := qElement + if eQMinus1[0] != 0 { + eQMinus1[0]-- + } else { + eQMinus1[0] = ^uint64(0) + for i := 1; i < len(eQMinus1); i++ { + if eQMinus1[i] != 0 { + eQMinus1[i]-- + break + } + } } for i := 0; i < N; i++ { diff --git a/ecc/secp256k1/fp/vector.go b/ecc/secp256k1/fp/vector.go index 1c29dcbf2..f0db4a2c1 100644 --- a/ecc/secp256k1/fp/vector.go +++ b/ecc/secp256k1/fp/vector.go @@ -199,6 +199,30 @@ func (vector Vector) Swap(i, j int) { vector[i], vector[j] = vector[j], vector[i] } +// Add adds two vectors element-wise and stores the result in self. +// It panics if the vectors don't have the same length. +func (vector *Vector) Add(a, b Vector) { + addVecGeneric(*vector, a, b) +} + +// Sub subtracts two vectors element-wise and stores the result in self. +// It panics if the vectors don't have the same length. +func (vector *Vector) Sub(a, b Vector) { + subVecGeneric(*vector, a, b) +} + +// ScalarMul multiplies a vector by a scalar element-wise and stores the result in self. +// It panics if the vectors don't have the same length. +func (vector *Vector) ScalarMul(a Vector, b *Element) { + scalarMulVecGeneric(*vector, a, b) +} + +// Sum computes the sum of all elements in the vector. 
+func (vector *Vector) Sum() (res Element) { + sumVecGeneric(&res, *vector) + return +} + func addVecGeneric(res, a, b Vector) { if len(a) != len(b) || len(a) != len(res) { panic("vector.Add: vectors don't have the same length") diff --git a/ecc/secp256k1/fr/element.go b/ecc/secp256k1/fr/element.go index 6afe3590b..37f3277d5 100644 --- a/ecc/secp256k1/fr/element.go +++ b/ecc/secp256k1/fr/element.go @@ -505,32 +505,8 @@ func (z *Element) Select(c int, x0 *Element, x1 *Element) *Element { // and is used for testing purposes. func _mulGeneric(z, x, y *Element) { - // Implements CIOS multiplication -- section 2.3.2 of Tolga Acar's thesis - // https://www.microsoft.com/en-us/research/wp-content/uploads/1998/06/97Acar.pdf - // - // The algorithm: - // - // for i=0 to N-1 - // C := 0 - // for j=0 to N-1 - // (C,t[j]) := t[j] + x[j]*y[i] + C - // (t[N+1],t[N]) := t[N] + C - // - // C := 0 - // m := t[0]*q'[0] mod D - // (C,_) := t[0] + m*q[0] - // for j=1 to N-1 - // (C,t[j-1]) := t[j] + m*q[j] + C - // - // (C,t[N-1]) := t[N] + C - // t[N] := t[N+1] + C - // - // → N is the number of machine words needed to store the modulus q - // → D is the word size. For example, on a 64-bit architecture D is 2 64 - // → x[i], y[i], q[i] is the ith word of the numbers x,y,q - // → q'[0] is the lowest word of the number -q⁻¹ mod r. This quantity is pre-computed, as it does not depend on the inputs. - // → t is a temporary array of size N+2 - // → C, S are machine words. A pair (C,S) refers to (hi-bits, lo-bits) of a two-word number + // Algorithm 2 of "Faster Montgomery Multiplication and Multi-Scalar-Multiplication for SNARKS" + // by Y. El Housni and G. Botrel https://doi.org/10.46586/tches.v2023.i3.504-521 var t [5]uint64 var D uint64 diff --git a/ecc/secp256k1/fr/element_ops_purego.go b/ecc/secp256k1/fr/element_ops_purego.go index a9a314406..ef83ea20a 100644 --- a/ecc/secp256k1/fr/element_ops_purego.go +++ b/ecc/secp256k1/fr/element_ops_purego.go @@ -57,59 +57,11 @@ func reduce(z *Element) { _reduceGeneric(z) } -// Add adds two vectors element-wise and stores the result in self. -// It panics if the vectors don't have the same length. -func (vector *Vector) Add(a, b Vector) { - addVecGeneric(*vector, a, b) -} - -// Sub subtracts two vectors element-wise and stores the result in self. -// It panics if the vectors don't have the same length. -func (vector *Vector) Sub(a, b Vector) { - subVecGeneric(*vector, a, b) -} - -// ScalarMul multiplies a vector by a scalar element-wise and stores the result in self. -// It panics if the vectors don't have the same length. -func (vector *Vector) ScalarMul(a Vector, b *Element) { - scalarMulVecGeneric(*vector, a, b) -} - -// Sum computes the sum of all elements in the vector. -func (vector *Vector) Sum() (res Element) { - sumVecGeneric(&res, *vector) - return -} - // Mul z = x * y (mod q) func (z *Element) Mul(x, y *Element) *Element { - // Implements CIOS multiplication -- section 2.3.2 of Tolga Acar's thesis - // https://www.microsoft.com/en-us/research/wp-content/uploads/1998/06/97Acar.pdf - // - // The algorithm: - // - // for i=0 to N-1 - // C := 0 - // for j=0 to N-1 - // (C,t[j]) := t[j] + x[j]*y[i] + C - // (t[N+1],t[N]) := t[N] + C - // - // C := 0 - // m := t[0]*q'[0] mod D - // (C,_) := t[0] + m*q[0] - // for j=1 to N-1 - // (C,t[j-1]) := t[j] + m*q[j] + C - // - // (C,t[N-1]) := t[N] + C - // t[N] := t[N+1] + C - // - // → N is the number of machine words needed to store the modulus q - // → D is the word size. 
For example, on a 64-bit architecture D is 2 64 - // → x[i], y[i], q[i] is the ith word of the numbers x,y,q - // → q'[0] is the lowest word of the number -q⁻¹ mod r. This quantity is pre-computed, as it does not depend on the inputs. - // → t is a temporary array of size N+2 - // → C, S are machine words. A pair (C,S) refers to (hi-bits, lo-bits) of a two-word number + // Algorithm 2 of "Faster Montgomery Multiplication and Multi-Scalar-Multiplication for SNARKS" + // by Y. El Housni and G. Botrel https://doi.org/10.46586/tches.v2023.i3.504-521 var t [5]uint64 var D uint64 diff --git a/ecc/secp256k1/fr/element_test.go b/ecc/secp256k1/fr/element_test.go index c3f0d3ac2..3381fdfd9 100644 --- a/ecc/secp256k1/fr/element_test.go +++ b/ecc/secp256k1/fr/element_test.go @@ -635,7 +635,6 @@ func TestElementBitLen(t *testing.T) { )) properties.TestingRun(t, gopter.ConsoleReporter(false)) - } func TestElementButterflies(t *testing.T) { @@ -717,11 +716,17 @@ func TestElementVecOps(t *testing.T) { // set m to max values element // it's not really q-1 (since we have montgomery representation) // but it's the "largest" legal value - qMinus1 := new(big.Int).Sub(Modulus(), big.NewInt(1)) - - var eQMinus1 Element - for i, v := range qMinus1.Bits() { - eQMinus1[i] = uint64(v) + eQMinus1 := qElement + if eQMinus1[0] != 0 { + eQMinus1[0]-- + } else { + eQMinus1[0] = ^uint64(0) + for i := 1; i < len(eQMinus1); i++ { + if eQMinus1[i] != 0 { + eQMinus1[i]-- + break + } + } } for i := 0; i < N; i++ { diff --git a/ecc/secp256k1/fr/vector.go b/ecc/secp256k1/fr/vector.go index d6a66b036..87a7e825f 100644 --- a/ecc/secp256k1/fr/vector.go +++ b/ecc/secp256k1/fr/vector.go @@ -199,6 +199,30 @@ func (vector Vector) Swap(i, j int) { vector[i], vector[j] = vector[j], vector[i] } +// Add adds two vectors element-wise and stores the result in self. +// It panics if the vectors don't have the same length. +func (vector *Vector) Add(a, b Vector) { + addVecGeneric(*vector, a, b) +} + +// Sub subtracts two vectors element-wise and stores the result in self. +// It panics if the vectors don't have the same length. +func (vector *Vector) Sub(a, b Vector) { + subVecGeneric(*vector, a, b) +} + +// ScalarMul multiplies a vector by a scalar element-wise and stores the result in self. +// It panics if the vectors don't have the same length. +func (vector *Vector) ScalarMul(a Vector, b *Element) { + scalarMulVecGeneric(*vector, a, b) +} + +// Sum computes the sum of all elements in the vector. +func (vector *Vector) Sum() (res Element) { + sumVecGeneric(&res, *vector) + return +} + func addVecGeneric(res, a, b Vector) { if len(a) != len(b) || len(a) != len(res) { panic("vector.Add: vectors don't have the same length") diff --git a/ecc/stark-curve/fp/element.go b/ecc/stark-curve/fp/element.go index 7a057be06..1715d6caf 100644 --- a/ecc/stark-curve/fp/element.go +++ b/ecc/stark-curve/fp/element.go @@ -477,32 +477,8 @@ func (z *Element) Select(c int, x0 *Element, x1 *Element) *Element { // and is used for testing purposes. 
func _mulGeneric(z, x, y *Element) { - // Implements CIOS multiplication -- section 2.3.2 of Tolga Acar's thesis - // https://www.microsoft.com/en-us/research/wp-content/uploads/1998/06/97Acar.pdf - // - // The algorithm: - // - // for i=0 to N-1 - // C := 0 - // for j=0 to N-1 - // (C,t[j]) := t[j] + x[j]*y[i] + C - // (t[N+1],t[N]) := t[N] + C - // - // C := 0 - // m := t[0]*q'[0] mod D - // (C,_) := t[0] + m*q[0] - // for j=1 to N-1 - // (C,t[j-1]) := t[j] + m*q[j] + C - // - // (C,t[N-1]) := t[N] + C - // t[N] := t[N+1] + C - // - // → N is the number of machine words needed to store the modulus q - // → D is the word size. For example, on a 64-bit architecture D is 2 64 - // → x[i], y[i], q[i] is the ith word of the numbers x,y,q - // → q'[0] is the lowest word of the number -q⁻¹ mod r. This quantity is pre-computed, as it does not depend on the inputs. - // → t is a temporary array of size N+2 - // → C, S are machine words. A pair (C,S) refers to (hi-bits, lo-bits) of a two-word number + // Algorithm 2 of "Faster Montgomery Multiplication and Multi-Scalar-Multiplication for SNARKS" + // by Y. El Housni and G. Botrel https://doi.org/10.46586/tches.v2023.i3.504-521 var t [5]uint64 var D uint64 diff --git a/ecc/stark-curve/fp/element_mul_amd64.s b/ecc/stark-curve/fp/element_mul_amd64.s deleted file mode 100644 index 36bbb8a76..000000000 --- a/ecc/stark-curve/fp/element_mul_amd64.s +++ /dev/null @@ -1,490 +0,0 @@ -// +build !purego - -// Copyright 2020 ConsenSys Software Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
- -#include "textflag.h" -#include "funcdata.h" - -// modulus q -DATA q<>+0(SB)/8, $1 -DATA q<>+8(SB)/8, $0 -DATA q<>+16(SB)/8, $0 -DATA q<>+24(SB)/8, $0x0800000000000011 -GLOBL q<>(SB), (RODATA+NOPTR), $32 - -// qInv0 q'[0] -DATA qInv0<>(SB)/8, $0xffffffffffffffff -GLOBL qInv0<>(SB), (RODATA+NOPTR), $8 -// Mu -DATA mu<>(SB)/8, $0x0000001fffffffff -GLOBL mu<>(SB), (RODATA+NOPTR), $8 - -#define REDUCE(ra0, ra1, ra2, ra3, rb0, rb1, rb2, rb3) \ - MOVQ ra0, rb0; \ - SUBQ q<>(SB), ra0; \ - MOVQ ra1, rb1; \ - SBBQ q<>+8(SB), ra1; \ - MOVQ ra2, rb2; \ - SBBQ q<>+16(SB), ra2; \ - MOVQ ra3, rb3; \ - SBBQ q<>+24(SB), ra3; \ - CMOVQCS rb0, ra0; \ - CMOVQCS rb1, ra1; \ - CMOVQCS rb2, ra2; \ - CMOVQCS rb3, ra3; \ - -// mul(res, x, y *Element) -TEXT ·mul(SB), $24-24 - - // the algorithm is described in the Element.Mul declaration (.go) - // however, to benefit from the ADCX and ADOX carry chains - // we split the inner loops in 2: - // for i=0 to N-1 - // for j=0 to N-1 - // (A,t[j]) := t[j] + x[j]*y[i] + A - // m := t[0]*q'[0] mod W - // C,_ := t[0] + m*q[0] - // for j=1 to N-1 - // (C,t[j-1]) := t[j] + m*q[j] + C - // t[N-1] = C + A - - NO_LOCAL_POINTERS - CMPB ·supportAdx(SB), $1 - JNE noAdx_1 - MOVQ x+8(FP), SI - - // x[0] -> DI - // x[1] -> R8 - // x[2] -> R9 - // x[3] -> R10 - MOVQ 0(SI), DI - MOVQ 8(SI), R8 - MOVQ 16(SI), R9 - MOVQ 24(SI), R10 - MOVQ y+16(FP), R11 - - // A -> BP - // t[0] -> R14 - // t[1] -> R13 - // t[2] -> CX - // t[3] -> BX - // clear the flags - XORQ AX, AX - MOVQ 0(R11), DX - - // (A,t[0]) := x[0]*y[0] + A - MULXQ DI, R14, R13 - - // (A,t[1]) := x[1]*y[0] + A - MULXQ R8, AX, CX - ADOXQ AX, R13 - - // (A,t[2]) := x[2]*y[0] + A - MULXQ R9, AX, BX - ADOXQ AX, CX - - // (A,t[3]) := x[3]*y[0] + A - MULXQ R10, AX, BP - ADOXQ AX, BX - - // A += carries from ADCXQ and ADOXQ - MOVQ $0, AX - ADOXQ AX, BP - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ R14, DX - - // clear the flags - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ q<>+0(SB), AX, R12 - ADCXQ R14, AX - MOVQ R12, R14 - - // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R13, R14 - MULXQ q<>+8(SB), AX, R13 - ADOXQ AX, R14 - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ CX, R13 - MULXQ q<>+16(SB), AX, CX - ADOXQ AX, R13 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ BX, CX - MULXQ q<>+24(SB), AX, BX - ADOXQ AX, CX - - // t[3] = C + A - MOVQ $0, AX - ADCXQ AX, BX - ADOXQ BP, BX - - // clear the flags - XORQ AX, AX - MOVQ 8(R11), DX - - // (A,t[0]) := t[0] + x[0]*y[1] + A - MULXQ DI, AX, BP - ADOXQ AX, R14 - - // (A,t[1]) := t[1] + x[1]*y[1] + A - ADCXQ BP, R13 - MULXQ R8, AX, BP - ADOXQ AX, R13 - - // (A,t[2]) := t[2] + x[2]*y[1] + A - ADCXQ BP, CX - MULXQ R9, AX, BP - ADOXQ AX, CX - - // (A,t[3]) := t[3] + x[3]*y[1] + A - ADCXQ BP, BX - MULXQ R10, AX, BP - ADOXQ AX, BX - - // A += carries from ADCXQ and ADOXQ - MOVQ $0, AX - ADCXQ AX, BP - ADOXQ AX, BP - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ R14, DX - - // clear the flags - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ q<>+0(SB), AX, R12 - ADCXQ R14, AX - MOVQ R12, R14 - - // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R13, R14 - MULXQ q<>+8(SB), AX, R13 - ADOXQ AX, R14 - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ CX, R13 - MULXQ q<>+16(SB), AX, CX - ADOXQ AX, R13 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ BX, CX - MULXQ q<>+24(SB), AX, BX - ADOXQ AX, CX - - // t[3] = C + A - MOVQ $0, AX - ADCXQ AX, BX - ADOXQ BP, BX - - // clear the flags - XORQ AX, AX - MOVQ 16(R11), DX - - // (A,t[0]) := t[0] + x[0]*y[2] + A - MULXQ DI, AX, BP - ADOXQ AX, R14 - - // 
(A,t[1]) := t[1] + x[1]*y[2] + A - ADCXQ BP, R13 - MULXQ R8, AX, BP - ADOXQ AX, R13 - - // (A,t[2]) := t[2] + x[2]*y[2] + A - ADCXQ BP, CX - MULXQ R9, AX, BP - ADOXQ AX, CX - - // (A,t[3]) := t[3] + x[3]*y[2] + A - ADCXQ BP, BX - MULXQ R10, AX, BP - ADOXQ AX, BX - - // A += carries from ADCXQ and ADOXQ - MOVQ $0, AX - ADCXQ AX, BP - ADOXQ AX, BP - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ R14, DX - - // clear the flags - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ q<>+0(SB), AX, R12 - ADCXQ R14, AX - MOVQ R12, R14 - - // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R13, R14 - MULXQ q<>+8(SB), AX, R13 - ADOXQ AX, R14 - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ CX, R13 - MULXQ q<>+16(SB), AX, CX - ADOXQ AX, R13 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ BX, CX - MULXQ q<>+24(SB), AX, BX - ADOXQ AX, CX - - // t[3] = C + A - MOVQ $0, AX - ADCXQ AX, BX - ADOXQ BP, BX - - // clear the flags - XORQ AX, AX - MOVQ 24(R11), DX - - // (A,t[0]) := t[0] + x[0]*y[3] + A - MULXQ DI, AX, BP - ADOXQ AX, R14 - - // (A,t[1]) := t[1] + x[1]*y[3] + A - ADCXQ BP, R13 - MULXQ R8, AX, BP - ADOXQ AX, R13 - - // (A,t[2]) := t[2] + x[2]*y[3] + A - ADCXQ BP, CX - MULXQ R9, AX, BP - ADOXQ AX, CX - - // (A,t[3]) := t[3] + x[3]*y[3] + A - ADCXQ BP, BX - MULXQ R10, AX, BP - ADOXQ AX, BX - - // A += carries from ADCXQ and ADOXQ - MOVQ $0, AX - ADCXQ AX, BP - ADOXQ AX, BP - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ R14, DX - - // clear the flags - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ q<>+0(SB), AX, R12 - ADCXQ R14, AX - MOVQ R12, R14 - - // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R13, R14 - MULXQ q<>+8(SB), AX, R13 - ADOXQ AX, R14 - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ CX, R13 - MULXQ q<>+16(SB), AX, CX - ADOXQ AX, R13 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ BX, CX - MULXQ q<>+24(SB), AX, BX - ADOXQ AX, CX - - // t[3] = C + A - MOVQ $0, AX - ADCXQ AX, BX - ADOXQ BP, BX - - // reduce element(R14,R13,CX,BX) using temp registers (SI,R12,R11,DI) - REDUCE(R14,R13,CX,BX,SI,R12,R11,DI) - - MOVQ res+0(FP), AX - MOVQ R14, 0(AX) - MOVQ R13, 8(AX) - MOVQ CX, 16(AX) - MOVQ BX, 24(AX) - RET - -noAdx_1: - MOVQ res+0(FP), AX - MOVQ AX, (SP) - MOVQ x+8(FP), AX - MOVQ AX, 8(SP) - MOVQ y+16(FP), AX - MOVQ AX, 16(SP) - CALL ·_mulGeneric(SB) - RET - -TEXT ·fromMont(SB), $8-8 - NO_LOCAL_POINTERS - - // the algorithm is described here - // https://hackmd.io/@gnark/modular_multiplication - // when y = 1 we have: - // for i=0 to N-1 - // t[i] = x[i] - // for i=0 to N-1 - // m := t[0]*q'[0] mod W - // C,_ := t[0] + m*q[0] - // for j=1 to N-1 - // (C,t[j-1]) := t[j] + m*q[j] + C - // t[N-1] = C - CMPB ·supportAdx(SB), $1 - JNE noAdx_2 - MOVQ res+0(FP), DX - MOVQ 0(DX), R14 - MOVQ 8(DX), R13 - MOVQ 16(DX), CX - MOVQ 24(DX), BX - XORQ DX, DX - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ R14, DX - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ q<>+0(SB), AX, BP - ADCXQ R14, AX - MOVQ BP, R14 - - // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R13, R14 - MULXQ q<>+8(SB), AX, R13 - ADOXQ AX, R14 - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ CX, R13 - MULXQ q<>+16(SB), AX, CX - ADOXQ AX, R13 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ BX, CX - MULXQ q<>+24(SB), AX, BX - ADOXQ AX, CX - MOVQ $0, AX - ADCXQ AX, BX - ADOXQ AX, BX - XORQ DX, DX - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ R14, DX - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ q<>+0(SB), AX, BP - ADCXQ R14, AX - MOVQ BP, R14 - - // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R13, R14 - MULXQ q<>+8(SB), AX, R13 - ADOXQ AX, 
R14 - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ CX, R13 - MULXQ q<>+16(SB), AX, CX - ADOXQ AX, R13 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ BX, CX - MULXQ q<>+24(SB), AX, BX - ADOXQ AX, CX - MOVQ $0, AX - ADCXQ AX, BX - ADOXQ AX, BX - XORQ DX, DX - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ R14, DX - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ q<>+0(SB), AX, BP - ADCXQ R14, AX - MOVQ BP, R14 - - // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R13, R14 - MULXQ q<>+8(SB), AX, R13 - ADOXQ AX, R14 - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ CX, R13 - MULXQ q<>+16(SB), AX, CX - ADOXQ AX, R13 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ BX, CX - MULXQ q<>+24(SB), AX, BX - ADOXQ AX, CX - MOVQ $0, AX - ADCXQ AX, BX - ADOXQ AX, BX - XORQ DX, DX - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ R14, DX - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ q<>+0(SB), AX, BP - ADCXQ R14, AX - MOVQ BP, R14 - - // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R13, R14 - MULXQ q<>+8(SB), AX, R13 - ADOXQ AX, R14 - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ CX, R13 - MULXQ q<>+16(SB), AX, CX - ADOXQ AX, R13 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ BX, CX - MULXQ q<>+24(SB), AX, BX - ADOXQ AX, CX - MOVQ $0, AX - ADCXQ AX, BX - ADOXQ AX, BX - - // reduce element(R14,R13,CX,BX) using temp registers (SI,DI,R8,R9) - REDUCE(R14,R13,CX,BX,SI,DI,R8,R9) - - MOVQ res+0(FP), AX - MOVQ R14, 0(AX) - MOVQ R13, 8(AX) - MOVQ CX, 16(AX) - MOVQ BX, 24(AX) - RET - -noAdx_2: - MOVQ res+0(FP), AX - MOVQ AX, (SP) - CALL ·_fromMontGeneric(SB) - RET diff --git a/ecc/stark-curve/fp/element_ops_amd64.go b/ecc/stark-curve/fp/element_ops_amd64.go index 328d9c4ab..4b1a9f449 100644 --- a/ecc/stark-curve/fp/element_ops_amd64.go +++ b/ecc/stark-curve/fp/element_ops_amd64.go @@ -103,48 +103,8 @@ func sumVec(res *Element, a *Element, n uint64) // x and y must be less than q func (z *Element) Mul(x, y *Element) *Element { - // Implements CIOS multiplication -- section 2.3.2 of Tolga Acar's thesis - // https://www.microsoft.com/en-us/research/wp-content/uploads/1998/06/97Acar.pdf - // - // The algorithm: - // - // for i=0 to N-1 - // C := 0 - // for j=0 to N-1 - // (C,t[j]) := t[j] + x[j]*y[i] + C - // (t[N+1],t[N]) := t[N] + C - // - // C := 0 - // m := t[0]*q'[0] mod D - // (C,_) := t[0] + m*q[0] - // for j=1 to N-1 - // (C,t[j-1]) := t[j] + m*q[j] + C - // - // (C,t[N-1]) := t[N] + C - // t[N] := t[N+1] + C - // - // → N is the number of machine words needed to store the modulus q - // → D is the word size. For example, on a 64-bit architecture D is 2 64 - // → x[i], y[i], q[i] is the ith word of the numbers x,y,q - // → q'[0] is the lowest word of the number -q⁻¹ mod r. This quantity is pre-computed, as it does not depend on the inputs. - // → t is a temporary array of size N+2 - // → C, S are machine words. A pair (C,S) refers to (hi-bits, lo-bits) of a two-word number - // - // As described here https://hackmd.io/@gnark/modular_multiplication we can get rid of one carry chain and simplify: - // (also described in https://eprint.iacr.org/2022/1400.pdf annex) - // - // for i=0 to N-1 - // (A,t[0]) := t[0] + x[0]*y[i] - // m := t[0]*q'[0] mod W - // C,_ := t[0] + m*q[0] - // for j=1 to N-1 - // (A,t[j]) := t[j] + x[j]*y[i] + A - // (C,t[j-1]) := t[j] + m*q[j] + C - // - // t[N-1] = C + A - // - // This optimization saves 5N + 2 additions in the algorithm, and can be used whenever the highest bit - // of the modulus is zero (and not all of the remaining bits are set). 
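The long CIOS walk-through deleted above is now summarized by the reference to Algorithm 2 of the El Housni and Botrel paper. The key invariant of each reduction step can be checked with a minimal pure-Go sketch, using the stark-curve fp constants that appear in the assembly in this patch (q[0] = 1, qInv0 = 0xffffffffffffffff, i.e. -q^-1 mod 2^64); the value of t[0] is arbitrary.

package main

import (
	"fmt"
	"math/bits"
)

func main() {
	const q0 = uint64(1)                     // lowest word of the fp modulus
	const qInv0 = uint64(0xffffffffffffffff) // -q^{-1} mod 2^64
	t0 := uint64(0xdeadbeefcafebabe)         // arbitrary partial-product word

	// m is chosen so that t[0] + m*q[0] is divisible by 2^64,
	// which lets the loop drop one word per outer iteration.
	m := t0 * qInv0 // wrapping multiply == reduction mod 2^64
	hi, lo := bits.Mul64(m, q0)
	lo, carry := bits.Add64(lo, t0, 0)
	hi += carry

	fmt.Printf("low word after adding m*q[0]: %#x (always 0)\n", lo)
	fmt.Printf("carry into the next word:     %#x\n", hi)
}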
+ // Algorithm 2 of "Faster Montgomery Multiplication and Multi-Scalar-Multiplication for SNARKS" + // by Y. El Housni and G. Botrel https://doi.org/10.46586/tches.v2023.i3.504-521 mul(z, x, y) return z diff --git a/ecc/stark-curve/fp/element_ops_amd64.s b/ecc/stark-curve/fp/element_ops_amd64.s index 98bd7fa01..d3139153b 100644 --- a/ecc/stark-curve/fp/element_ops_amd64.s +++ b/ecc/stark-curve/fp/element_ops_amd64.s @@ -1,21 +1,11 @@ // +build !purego -// Copyright 2020 ConsenSys Software Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. +#define q0 $0x0000000000000001 +#define q1 $0x0000000000000000 +#define q2 $0x0000000000000000 +#define q3 $0x0800000000000011 -#include "textflag.h" -#include "funcdata.h" +#include "../../../field/asm/element_4w_amd64.h" // modulus q DATA q<>+0(SB)/8, $1 @@ -31,809 +21,3 @@ GLOBL qInv0<>(SB), (RODATA+NOPTR), $8 DATA mu<>(SB)/8, $0x0000001fffffffff GLOBL mu<>(SB), (RODATA+NOPTR), $8 -#define REDUCE(ra0, ra1, ra2, ra3, rb0, rb1, rb2, rb3) \ - MOVQ ra0, rb0; \ - SUBQ q<>(SB), ra0; \ - MOVQ ra1, rb1; \ - SBBQ q<>+8(SB), ra1; \ - MOVQ ra2, rb2; \ - SBBQ q<>+16(SB), ra2; \ - MOVQ ra3, rb3; \ - SBBQ q<>+24(SB), ra3; \ - CMOVQCS rb0, ra0; \ - CMOVQCS rb1, ra1; \ - CMOVQCS rb2, ra2; \ - CMOVQCS rb3, ra3; \ - -TEXT ·reduce(SB), NOSPLIT, $0-8 - MOVQ res+0(FP), AX - MOVQ 0(AX), DX - MOVQ 8(AX), CX - MOVQ 16(AX), BX - MOVQ 24(AX), SI - - // reduce element(DX,CX,BX,SI) using temp registers (DI,R8,R9,R10) - REDUCE(DX,CX,BX,SI,DI,R8,R9,R10) - - MOVQ DX, 0(AX) - MOVQ CX, 8(AX) - MOVQ BX, 16(AX) - MOVQ SI, 24(AX) - RET - -// MulBy3(x *Element) -TEXT ·MulBy3(SB), NOSPLIT, $0-8 - MOVQ x+0(FP), AX - MOVQ 0(AX), DX - MOVQ 8(AX), CX - MOVQ 16(AX), BX - MOVQ 24(AX), SI - ADDQ DX, DX - ADCQ CX, CX - ADCQ BX, BX - ADCQ SI, SI - - // reduce element(DX,CX,BX,SI) using temp registers (DI,R8,R9,R10) - REDUCE(DX,CX,BX,SI,DI,R8,R9,R10) - - ADDQ 0(AX), DX - ADCQ 8(AX), CX - ADCQ 16(AX), BX - ADCQ 24(AX), SI - - // reduce element(DX,CX,BX,SI) using temp registers (R11,R12,R13,R14) - REDUCE(DX,CX,BX,SI,R11,R12,R13,R14) - - MOVQ DX, 0(AX) - MOVQ CX, 8(AX) - MOVQ BX, 16(AX) - MOVQ SI, 24(AX) - RET - -// MulBy5(x *Element) -TEXT ·MulBy5(SB), NOSPLIT, $0-8 - MOVQ x+0(FP), AX - MOVQ 0(AX), DX - MOVQ 8(AX), CX - MOVQ 16(AX), BX - MOVQ 24(AX), SI - ADDQ DX, DX - ADCQ CX, CX - ADCQ BX, BX - ADCQ SI, SI - - // reduce element(DX,CX,BX,SI) using temp registers (DI,R8,R9,R10) - REDUCE(DX,CX,BX,SI,DI,R8,R9,R10) - - ADDQ DX, DX - ADCQ CX, CX - ADCQ BX, BX - ADCQ SI, SI - - // reduce element(DX,CX,BX,SI) using temp registers (R11,R12,R13,R14) - REDUCE(DX,CX,BX,SI,R11,R12,R13,R14) - - ADDQ 0(AX), DX - ADCQ 8(AX), CX - ADCQ 16(AX), BX - ADCQ 24(AX), SI - - // reduce element(DX,CX,BX,SI) using temp registers (R15,DI,R8,R9) - REDUCE(DX,CX,BX,SI,R15,DI,R8,R9) - - MOVQ DX, 0(AX) - MOVQ CX, 8(AX) - MOVQ BX, 16(AX) - MOVQ SI, 24(AX) - RET - -// MulBy13(x *Element) -TEXT ·MulBy13(SB), NOSPLIT, $0-8 - MOVQ x+0(FP), AX - MOVQ 0(AX), DX - MOVQ 8(AX), CX - MOVQ 16(AX), BX - MOVQ 24(AX), SI - ADDQ DX, DX 
- ADCQ CX, CX - ADCQ BX, BX - ADCQ SI, SI - - // reduce element(DX,CX,BX,SI) using temp registers (DI,R8,R9,R10) - REDUCE(DX,CX,BX,SI,DI,R8,R9,R10) - - ADDQ DX, DX - ADCQ CX, CX - ADCQ BX, BX - ADCQ SI, SI - - // reduce element(DX,CX,BX,SI) using temp registers (R11,R12,R13,R14) - REDUCE(DX,CX,BX,SI,R11,R12,R13,R14) - - MOVQ DX, R11 - MOVQ CX, R12 - MOVQ BX, R13 - MOVQ SI, R14 - ADDQ DX, DX - ADCQ CX, CX - ADCQ BX, BX - ADCQ SI, SI - - // reduce element(DX,CX,BX,SI) using temp registers (DI,R8,R9,R10) - REDUCE(DX,CX,BX,SI,DI,R8,R9,R10) - - ADDQ R11, DX - ADCQ R12, CX - ADCQ R13, BX - ADCQ R14, SI - - // reduce element(DX,CX,BX,SI) using temp registers (DI,R8,R9,R10) - REDUCE(DX,CX,BX,SI,DI,R8,R9,R10) - - ADDQ 0(AX), DX - ADCQ 8(AX), CX - ADCQ 16(AX), BX - ADCQ 24(AX), SI - - // reduce element(DX,CX,BX,SI) using temp registers (DI,R8,R9,R10) - REDUCE(DX,CX,BX,SI,DI,R8,R9,R10) - - MOVQ DX, 0(AX) - MOVQ CX, 8(AX) - MOVQ BX, 16(AX) - MOVQ SI, 24(AX) - RET - -// Butterfly(a, b *Element) sets a = a + b; b = a - b -TEXT ·Butterfly(SB), NOSPLIT, $0-16 - MOVQ a+0(FP), AX - MOVQ 0(AX), CX - MOVQ 8(AX), BX - MOVQ 16(AX), SI - MOVQ 24(AX), DI - MOVQ CX, R8 - MOVQ BX, R9 - MOVQ SI, R10 - MOVQ DI, R11 - XORQ AX, AX - MOVQ b+8(FP), DX - ADDQ 0(DX), CX - ADCQ 8(DX), BX - ADCQ 16(DX), SI - ADCQ 24(DX), DI - SUBQ 0(DX), R8 - SBBQ 8(DX), R9 - SBBQ 16(DX), R10 - SBBQ 24(DX), R11 - MOVQ $1, R12 - MOVQ $0, R13 - MOVQ $0, R14 - MOVQ $0x0800000000000011, R15 - CMOVQCC AX, R12 - CMOVQCC AX, R13 - CMOVQCC AX, R14 - CMOVQCC AX, R15 - ADDQ R12, R8 - ADCQ R13, R9 - ADCQ R14, R10 - ADCQ R15, R11 - MOVQ R8, 0(DX) - MOVQ R9, 8(DX) - MOVQ R10, 16(DX) - MOVQ R11, 24(DX) - - // reduce element(CX,BX,SI,DI) using temp registers (R8,R9,R10,R11) - REDUCE(CX,BX,SI,DI,R8,R9,R10,R11) - - MOVQ a+0(FP), AX - MOVQ CX, 0(AX) - MOVQ BX, 8(AX) - MOVQ SI, 16(AX) - MOVQ DI, 24(AX) - RET - -// addVec(res, a, b *Element, n uint64) res[0...n] = a[0...n] + b[0...n] -TEXT ·addVec(SB), NOSPLIT, $0-32 - MOVQ res+0(FP), CX - MOVQ a+8(FP), AX - MOVQ b+16(FP), DX - MOVQ n+24(FP), BX - -loop_1: - TESTQ BX, BX - JEQ done_2 // n == 0, we are done - - // a[0] -> SI - // a[1] -> DI - // a[2] -> R8 - // a[3] -> R9 - MOVQ 0(AX), SI - MOVQ 8(AX), DI - MOVQ 16(AX), R8 - MOVQ 24(AX), R9 - ADDQ 0(DX), SI - ADCQ 8(DX), DI - ADCQ 16(DX), R8 - ADCQ 24(DX), R9 - - // reduce element(SI,DI,R8,R9) using temp registers (R10,R11,R12,R13) - REDUCE(SI,DI,R8,R9,R10,R11,R12,R13) - - MOVQ SI, 0(CX) - MOVQ DI, 8(CX) - MOVQ R8, 16(CX) - MOVQ R9, 24(CX) - - // increment pointers to visit next element - ADDQ $32, AX - ADDQ $32, DX - ADDQ $32, CX - DECQ BX // decrement n - JMP loop_1 - -done_2: - RET - -// subVec(res, a, b *Element, n uint64) res[0...n] = a[0...n] - b[0...n] -TEXT ·subVec(SB), NOSPLIT, $0-32 - MOVQ res+0(FP), CX - MOVQ a+8(FP), AX - MOVQ b+16(FP), DX - MOVQ n+24(FP), BX - XORQ SI, SI - -loop_3: - TESTQ BX, BX - JEQ done_4 // n == 0, we are done - - // a[0] -> DI - // a[1] -> R8 - // a[2] -> R9 - // a[3] -> R10 - MOVQ 0(AX), DI - MOVQ 8(AX), R8 - MOVQ 16(AX), R9 - MOVQ 24(AX), R10 - SUBQ 0(DX), DI - SBBQ 8(DX), R8 - SBBQ 16(DX), R9 - SBBQ 24(DX), R10 - - // reduce (a-b) mod q - // q[0] -> R11 - // q[1] -> R12 - // q[2] -> R13 - // q[3] -> R14 - MOVQ $1, R11 - MOVQ $0, R12 - MOVQ $0, R13 - MOVQ $0x0800000000000011, R14 - CMOVQCC SI, R11 - CMOVQCC SI, R12 - CMOVQCC SI, R13 - CMOVQCC SI, R14 - - // add registers (q or 0) to a, and set to result - ADDQ R11, DI - ADCQ R12, R8 - ADCQ R13, R9 - ADCQ R14, R10 - MOVQ DI, 0(CX) - MOVQ R8, 8(CX) - MOVQ R9, 16(CX) - MOVQ 
R10, 24(CX) - - // increment pointers to visit next element - ADDQ $32, AX - ADDQ $32, DX - ADDQ $32, CX - DECQ BX // decrement n - JMP loop_3 - -done_4: - RET - -// scalarMulVec(res, a, b *Element, n uint64) res[0...n] = a[0...n] * b -TEXT ·scalarMulVec(SB), $56-32 - CMPB ·supportAdx(SB), $1 - JNE noAdx_5 - MOVQ a+8(FP), R11 - MOVQ b+16(FP), R10 - MOVQ n+24(FP), R12 - - // scalar[0] -> SI - // scalar[1] -> DI - // scalar[2] -> R8 - // scalar[3] -> R9 - MOVQ 0(R10), SI - MOVQ 8(R10), DI - MOVQ 16(R10), R8 - MOVQ 24(R10), R9 - MOVQ res+0(FP), R10 - -loop_6: - TESTQ R12, R12 - JEQ done_7 // n == 0, we are done - - // TODO @gbotrel this is generated from the same macro as the unit mul, we should refactor this in a single asm function - // A -> BP - // t[0] -> R14 - // t[1] -> R15 - // t[2] -> CX - // t[3] -> BX - // clear the flags - XORQ AX, AX - MOVQ 0(R11), DX - - // (A,t[0]) := x[0]*y[0] + A - MULXQ SI, R14, R15 - - // (A,t[1]) := x[1]*y[0] + A - MULXQ DI, AX, CX - ADOXQ AX, R15 - - // (A,t[2]) := x[2]*y[0] + A - MULXQ R8, AX, BX - ADOXQ AX, CX - - // (A,t[3]) := x[3]*y[0] + A - MULXQ R9, AX, BP - ADOXQ AX, BX - - // A += carries from ADCXQ and ADOXQ - MOVQ $0, AX - ADOXQ AX, BP - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ R14, DX - - // clear the flags - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ q<>+0(SB), AX, R13 - ADCXQ R14, AX - MOVQ R13, R14 - - // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R15, R14 - MULXQ q<>+8(SB), AX, R15 - ADOXQ AX, R14 - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ CX, R15 - MULXQ q<>+16(SB), AX, CX - ADOXQ AX, R15 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ BX, CX - MULXQ q<>+24(SB), AX, BX - ADOXQ AX, CX - - // t[3] = C + A - MOVQ $0, AX - ADCXQ AX, BX - ADOXQ BP, BX - - // clear the flags - XORQ AX, AX - MOVQ 8(R11), DX - - // (A,t[0]) := t[0] + x[0]*y[1] + A - MULXQ SI, AX, BP - ADOXQ AX, R14 - - // (A,t[1]) := t[1] + x[1]*y[1] + A - ADCXQ BP, R15 - MULXQ DI, AX, BP - ADOXQ AX, R15 - - // (A,t[2]) := t[2] + x[2]*y[1] + A - ADCXQ BP, CX - MULXQ R8, AX, BP - ADOXQ AX, CX - - // (A,t[3]) := t[3] + x[3]*y[1] + A - ADCXQ BP, BX - MULXQ R9, AX, BP - ADOXQ AX, BX - - // A += carries from ADCXQ and ADOXQ - MOVQ $0, AX - ADCXQ AX, BP - ADOXQ AX, BP - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ R14, DX - - // clear the flags - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ q<>+0(SB), AX, R13 - ADCXQ R14, AX - MOVQ R13, R14 - - // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R15, R14 - MULXQ q<>+8(SB), AX, R15 - ADOXQ AX, R14 - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ CX, R15 - MULXQ q<>+16(SB), AX, CX - ADOXQ AX, R15 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ BX, CX - MULXQ q<>+24(SB), AX, BX - ADOXQ AX, CX - - // t[3] = C + A - MOVQ $0, AX - ADCXQ AX, BX - ADOXQ BP, BX - - // clear the flags - XORQ AX, AX - MOVQ 16(R11), DX - - // (A,t[0]) := t[0] + x[0]*y[2] + A - MULXQ SI, AX, BP - ADOXQ AX, R14 - - // (A,t[1]) := t[1] + x[1]*y[2] + A - ADCXQ BP, R15 - MULXQ DI, AX, BP - ADOXQ AX, R15 - - // (A,t[2]) := t[2] + x[2]*y[2] + A - ADCXQ BP, CX - MULXQ R8, AX, BP - ADOXQ AX, CX - - // (A,t[3]) := t[3] + x[3]*y[2] + A - ADCXQ BP, BX - MULXQ R9, AX, BP - ADOXQ AX, BX - - // A += carries from ADCXQ and ADOXQ - MOVQ $0, AX - ADCXQ AX, BP - ADOXQ AX, BP - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ R14, DX - - // clear the flags - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ q<>+0(SB), AX, R13 - ADCXQ R14, AX - MOVQ R13, R14 - - // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R15, R14 - MULXQ q<>+8(SB), AX, R15 - ADOXQ AX, R14 - - // 
(C,t[1]) := t[2] + m*q[2] + C - ADCXQ CX, R15 - MULXQ q<>+16(SB), AX, CX - ADOXQ AX, R15 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ BX, CX - MULXQ q<>+24(SB), AX, BX - ADOXQ AX, CX - - // t[3] = C + A - MOVQ $0, AX - ADCXQ AX, BX - ADOXQ BP, BX - - // clear the flags - XORQ AX, AX - MOVQ 24(R11), DX - - // (A,t[0]) := t[0] + x[0]*y[3] + A - MULXQ SI, AX, BP - ADOXQ AX, R14 - - // (A,t[1]) := t[1] + x[1]*y[3] + A - ADCXQ BP, R15 - MULXQ DI, AX, BP - ADOXQ AX, R15 - - // (A,t[2]) := t[2] + x[2]*y[3] + A - ADCXQ BP, CX - MULXQ R8, AX, BP - ADOXQ AX, CX - - // (A,t[3]) := t[3] + x[3]*y[3] + A - ADCXQ BP, BX - MULXQ R9, AX, BP - ADOXQ AX, BX - - // A += carries from ADCXQ and ADOXQ - MOVQ $0, AX - ADCXQ AX, BP - ADOXQ AX, BP - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ R14, DX - - // clear the flags - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ q<>+0(SB), AX, R13 - ADCXQ R14, AX - MOVQ R13, R14 - - // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R15, R14 - MULXQ q<>+8(SB), AX, R15 - ADOXQ AX, R14 - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ CX, R15 - MULXQ q<>+16(SB), AX, CX - ADOXQ AX, R15 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ BX, CX - MULXQ q<>+24(SB), AX, BX - ADOXQ AX, CX - - // t[3] = C + A - MOVQ $0, AX - ADCXQ AX, BX - ADOXQ BP, BX - - // reduce t mod q - // reduce element(R14,R15,CX,BX) using temp registers (R13,AX,DX,s0-8(SP)) - REDUCE(R14,R15,CX,BX,R13,AX,DX,s0-8(SP)) - - MOVQ R14, 0(R10) - MOVQ R15, 8(R10) - MOVQ CX, 16(R10) - MOVQ BX, 24(R10) - - // increment pointers to visit next element - ADDQ $32, R11 - ADDQ $32, R10 - DECQ R12 // decrement n - JMP loop_6 - -done_7: - RET - -noAdx_5: - MOVQ n+24(FP), DX - MOVQ res+0(FP), AX - MOVQ AX, (SP) - MOVQ DX, 8(SP) - MOVQ DX, 16(SP) - MOVQ a+8(FP), AX - MOVQ AX, 24(SP) - MOVQ DX, 32(SP) - MOVQ DX, 40(SP) - MOVQ b+16(FP), AX - MOVQ AX, 48(SP) - CALL ·scalarMulVecGeneric(SB) - RET - -// sumVec(res, a *Element, n uint64) res = sum(a[0...n]) -TEXT ·sumVec(SB), NOSPLIT, $0-24 - - // Derived from https://github.com/a16z/vectorized-fields - // The idea is to use Z registers to accumulate the sum of elements, 8 by 8 - // first, we handle the case where n % 8 != 0 - // then, we loop over the elements 8 by 8 and accumulate the sum in the Z registers - // finally, we reduce the sum and store it in res - // - // when we move an element of a into a Z register, we use VPMOVZXDQ - // let's note w0...w3 the 4 64bits words of ai: w0 = ai[0], w1 = ai[1], w2 = ai[2], w3 = ai[3] - // VPMOVZXDQ(ai, Z0) will result in - // Z0= [hi(w3), lo(w3), hi(w2), lo(w2), hi(w1), lo(w1), hi(w0), lo(w0)] - // with hi(wi) the high 32 bits of wi and lo(wi) the low 32 bits of wi - // we can safely add 2^32+1 times Z registers constructed this way without overflow - // since each of this lo/hi bits are moved into a "64bits" slot - // N = 2^64-1 / 2^32-1 = 2^32+1 - // - // we then propagate the carry using ADOXQ and ADCXQ - // r0 = w0l + lo(woh) - // r1 = carry + hi(woh) + w1l + lo(w1h) - // r2 = carry + hi(w1h) + w2l + lo(w2h) - // r3 = carry + hi(w2h) + w3l + lo(w3h) - // r4 = carry + hi(w3h) - // we then reduce the sum using a single-word Barrett reduction - // we pick mu = 2^288 / q; which correspond to 4.5 words max. - // meaning we must guarantee that r4 fits in 32bits. 
- // To do so, we reduce N to 2^32-1 (since r4 receives 2 carries max) - - MOVQ a+8(FP), R14 - MOVQ n+16(FP), R15 - - // initialize accumulators Z0, Z1, Z2, Z3, Z4, Z5, Z6, Z7 - VXORPS Z0, Z0, Z0 - VMOVDQA64 Z0, Z1 - VMOVDQA64 Z0, Z2 - VMOVDQA64 Z0, Z3 - VMOVDQA64 Z0, Z4 - VMOVDQA64 Z0, Z5 - VMOVDQA64 Z0, Z6 - VMOVDQA64 Z0, Z7 - - // n % 8 -> CX - // n / 8 -> R15 - MOVQ R15, CX - ANDQ $7, CX - SHRQ $3, R15 - -loop_single_10: - TESTQ CX, CX - JEQ loop8by8_8 // n % 8 == 0, we are going to loop over 8 by 8 - VPMOVZXDQ 0(R14), Z8 - VPADDQ Z8, Z0, Z0 - ADDQ $32, R14 - DECQ CX // decrement nMod8 - JMP loop_single_10 - -loop8by8_8: - TESTQ R15, R15 - JEQ accumulate_11 // n == 0, we are going to accumulate - VPMOVZXDQ 0*32(R14), Z8 - VPMOVZXDQ 1*32(R14), Z9 - VPMOVZXDQ 2*32(R14), Z10 - VPMOVZXDQ 3*32(R14), Z11 - VPMOVZXDQ 4*32(R14), Z12 - VPMOVZXDQ 5*32(R14), Z13 - VPMOVZXDQ 6*32(R14), Z14 - VPMOVZXDQ 7*32(R14), Z15 - PREFETCHT0 256(R14) - VPADDQ Z8, Z0, Z0 - VPADDQ Z9, Z1, Z1 - VPADDQ Z10, Z2, Z2 - VPADDQ Z11, Z3, Z3 - VPADDQ Z12, Z4, Z4 - VPADDQ Z13, Z5, Z5 - VPADDQ Z14, Z6, Z6 - VPADDQ Z15, Z7, Z7 - - // increment pointers to visit next 8 elements - ADDQ $256, R14 - DECQ R15 // decrement n - JMP loop8by8_8 - -accumulate_11: - // accumulate the 8 Z registers into Z0 - VPADDQ Z7, Z6, Z6 - VPADDQ Z6, Z5, Z5 - VPADDQ Z5, Z4, Z4 - VPADDQ Z4, Z3, Z3 - VPADDQ Z3, Z2, Z2 - VPADDQ Z2, Z1, Z1 - VPADDQ Z1, Z0, Z0 - - // carry propagation - // lo(w0) -> BX - // hi(w0) -> SI - // lo(w1) -> DI - // hi(w1) -> R8 - // lo(w2) -> R9 - // hi(w2) -> R10 - // lo(w3) -> R11 - // hi(w3) -> R12 - VMOVQ X0, BX - VALIGNQ $1, Z0, Z0, Z0 - VMOVQ X0, SI - VALIGNQ $1, Z0, Z0, Z0 - VMOVQ X0, DI - VALIGNQ $1, Z0, Z0, Z0 - VMOVQ X0, R8 - VALIGNQ $1, Z0, Z0, Z0 - VMOVQ X0, R9 - VALIGNQ $1, Z0, Z0, Z0 - VMOVQ X0, R10 - VALIGNQ $1, Z0, Z0, Z0 - VMOVQ X0, R11 - VALIGNQ $1, Z0, Z0, Z0 - VMOVQ X0, R12 - - // lo(hi(wo)) -> R13 - // lo(hi(w1)) -> CX - // lo(hi(w2)) -> R15 - // lo(hi(w3)) -> R14 -#define SPLIT_LO_HI(lo, hi) \ - MOVQ hi, lo; \ - ANDQ $0xffffffff, lo; \ - SHLQ $32, lo; \ - SHRQ $32, hi; \ - - SPLIT_LO_HI(R13, SI) - SPLIT_LO_HI(CX, R8) - SPLIT_LO_HI(R15, R10) - SPLIT_LO_HI(R14, R12) - - // r0 = w0l + lo(woh) - // r1 = carry + hi(woh) + w1l + lo(w1h) - // r2 = carry + hi(w1h) + w2l + lo(w2h) - // r3 = carry + hi(w2h) + w3l + lo(w3h) - // r4 = carry + hi(w3h) - - XORQ AX, AX // clear the flags - ADOXQ R13, BX - ADOXQ CX, DI - ADCXQ SI, DI - ADOXQ R15, R9 - ADCXQ R8, R9 - ADOXQ R14, R11 - ADCXQ R10, R11 - ADOXQ AX, R12 - ADCXQ AX, R12 - - // r[0] -> BX - // r[1] -> DI - // r[2] -> R9 - // r[3] -> R11 - // r[4] -> R12 - // reduce using single-word Barrett - // mu=2^288 / q -> SI - MOVQ mu<>(SB), SI - MOVQ R11, AX - SHRQ $32, R12, AX - MULQ SI // high bits of res stored in DX - MULXQ q<>+0(SB), AX, SI - SUBQ AX, BX - SBBQ SI, DI - MULXQ q<>+16(SB), AX, SI - SBBQ AX, R9 - SBBQ SI, R11 - SBBQ $0, R12 - MULXQ q<>+8(SB), AX, SI - SUBQ AX, DI - SBBQ SI, R9 - MULXQ q<>+24(SB), AX, SI - SBBQ AX, R11 - SBBQ SI, R12 - MOVQ BX, R8 - MOVQ DI, R10 - MOVQ R9, R13 - MOVQ R11, CX - SUBQ q<>+0(SB), BX - SBBQ q<>+8(SB), DI - SBBQ q<>+16(SB), R9 - SBBQ q<>+24(SB), R11 - SBBQ $0, R12 - JCS modReduced_12 - MOVQ BX, R8 - MOVQ DI, R10 - MOVQ R9, R13 - MOVQ R11, CX - SUBQ q<>+0(SB), BX - SBBQ q<>+8(SB), DI - SBBQ q<>+16(SB), R9 - SBBQ q<>+24(SB), R11 - SBBQ $0, R12 - JCS modReduced_12 - MOVQ BX, R8 - MOVQ DI, R10 - MOVQ R9, R13 - MOVQ R11, CX - -modReduced_12: - MOVQ res+0(FP), SI - MOVQ R8, 0(SI) - MOVQ R10, 8(SI) - MOVQ R13, 16(SI) - MOVQ CX, 
24(SI) - -done_9: - RET diff --git a/ecc/stark-curve/fp/element_ops_purego.go b/ecc/stark-curve/fp/element_ops_purego.go index c7a46aa0f..6a9a78132 100644 --- a/ecc/stark-curve/fp/element_ops_purego.go +++ b/ecc/stark-curve/fp/element_ops_purego.go @@ -89,48 +89,8 @@ func (vector *Vector) Sum() (res Element) { // x and y must be less than q func (z *Element) Mul(x, y *Element) *Element { - // Implements CIOS multiplication -- section 2.3.2 of Tolga Acar's thesis - // https://www.microsoft.com/en-us/research/wp-content/uploads/1998/06/97Acar.pdf - // - // The algorithm: - // - // for i=0 to N-1 - // C := 0 - // for j=0 to N-1 - // (C,t[j]) := t[j] + x[j]*y[i] + C - // (t[N+1],t[N]) := t[N] + C - // - // C := 0 - // m := t[0]*q'[0] mod D - // (C,_) := t[0] + m*q[0] - // for j=1 to N-1 - // (C,t[j-1]) := t[j] + m*q[j] + C - // - // (C,t[N-1]) := t[N] + C - // t[N] := t[N+1] + C - // - // → N is the number of machine words needed to store the modulus q - // → D is the word size. For example, on a 64-bit architecture D is 2 64 - // → x[i], y[i], q[i] is the ith word of the numbers x,y,q - // → q'[0] is the lowest word of the number -q⁻¹ mod r. This quantity is pre-computed, as it does not depend on the inputs. - // → t is a temporary array of size N+2 - // → C, S are machine words. A pair (C,S) refers to (hi-bits, lo-bits) of a two-word number - // - // As described here https://hackmd.io/@gnark/modular_multiplication we can get rid of one carry chain and simplify: - // (also described in https://eprint.iacr.org/2022/1400.pdf annex) - // - // for i=0 to N-1 - // (A,t[0]) := t[0] + x[0]*y[i] - // m := t[0]*q'[0] mod W - // C,_ := t[0] + m*q[0] - // for j=1 to N-1 - // (A,t[j]) := t[j] + x[j]*y[i] + A - // (C,t[j-1]) := t[j] + m*q[j] + C - // - // t[N-1] = C + A - // - // This optimization saves 5N + 2 additions in the algorithm, and can be used whenever the highest bit - // of the modulus is zero (and not all of the remaining bits are set). + // Algorithm 2 of "Faster Montgomery Multiplication and Multi-Scalar-Multiplication for SNARKS" + // by Y. El Housni and G. Botrel https://doi.org/10.46586/tches.v2023.i3.504-521 var t0, t1, t2, t3 uint64 var u0, u1, u2, u3 uint64 diff --git a/ecc/stark-curve/fp/element_test.go b/ecc/stark-curve/fp/element_test.go index 5541a687d..4a6de9ab7 100644 --- a/ecc/stark-curve/fp/element_test.go +++ b/ecc/stark-curve/fp/element_test.go @@ -637,7 +637,6 @@ func TestElementBitLen(t *testing.T) { )) properties.TestingRun(t, gopter.ConsoleReporter(false)) - } func TestElementButterflies(t *testing.T) { @@ -719,11 +718,17 @@ func TestElementVecOps(t *testing.T) { // set m to max values element // it's not really q-1 (since we have montgomery representation) // but it's the "largest" legal value - qMinus1 := new(big.Int).Sub(Modulus(), big.NewInt(1)) - - var eQMinus1 Element - for i, v := range qMinus1.Bits() { - eQMinus1[i] = uint64(v) + eQMinus1 := qElement + if eQMinus1[0] != 0 { + eQMinus1[0]-- + } else { + eQMinus1[0] = ^uint64(0) + for i := 1; i < len(eQMinus1); i++ { + if eQMinus1[i] != 0 { + eQMinus1[i]-- + break + } + } } for i := 0; i < N; i++ { diff --git a/ecc/stark-curve/fr/element.go b/ecc/stark-curve/fr/element.go index a7ab8e217..1e9090ab0 100644 --- a/ecc/stark-curve/fr/element.go +++ b/ecc/stark-curve/fr/element.go @@ -477,32 +477,8 @@ func (z *Element) Select(c int, x0 *Element, x1 *Element) *Element { // and is used for testing purposes. 
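The sumVec code deleted above documents the zero-extension trick (each 64-bit word split into 32-bit halves held in 64-bit AVX-512 lanes) and the single-word Barrett reduction, presumably now provided by the shared element_4w_amd64.h header included by the new element_ops_amd64.s. A small pure-Go sketch of the split-and-recombine bookkeeping, with arbitrary sample words, makes the carry propagation concrete; the analogous stark-curve fr changes continue below with _mulGeneric.

package main

import (
	"fmt"
	"math/bits"
)

func main() {
	words := []uint64{0xffffffffffffffff, 0x0123456789abcdef, 0xdeadbeefcafef00d}

	var accLo, accHi uint64 // 64-bit slots holding sums of 32-bit halves (VPMOVZXDQ layout)
	var refLo, refHi uint64 // 128-bit reference sum of the original words
	for _, w := range words {
		accLo += w & 0xffffffff
		accHi += w >> 32
		var c uint64
		refLo, c = bits.Add64(refLo, w, 0)
		refHi += c
	}

	// recombine the halves, mirroring the scalar carry-propagation tail of sumVec
	lo, c := bits.Add64(accLo, accHi<<32, 0)
	hi := accHi>>32 + c

	fmt.Println(lo == refLo, hi == refHi) // true true
}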
func _mulGeneric(z, x, y *Element) { - // Implements CIOS multiplication -- section 2.3.2 of Tolga Acar's thesis - // https://www.microsoft.com/en-us/research/wp-content/uploads/1998/06/97Acar.pdf - // - // The algorithm: - // - // for i=0 to N-1 - // C := 0 - // for j=0 to N-1 - // (C,t[j]) := t[j] + x[j]*y[i] + C - // (t[N+1],t[N]) := t[N] + C - // - // C := 0 - // m := t[0]*q'[0] mod D - // (C,_) := t[0] + m*q[0] - // for j=1 to N-1 - // (C,t[j-1]) := t[j] + m*q[j] + C - // - // (C,t[N-1]) := t[N] + C - // t[N] := t[N+1] + C - // - // → N is the number of machine words needed to store the modulus q - // → D is the word size. For example, on a 64-bit architecture D is 2 64 - // → x[i], y[i], q[i] is the ith word of the numbers x,y,q - // → q'[0] is the lowest word of the number -q⁻¹ mod r. This quantity is pre-computed, as it does not depend on the inputs. - // → t is a temporary array of size N+2 - // → C, S are machine words. A pair (C,S) refers to (hi-bits, lo-bits) of a two-word number + // Algorithm 2 of "Faster Montgomery Multiplication and Multi-Scalar-Multiplication for SNARKS" + // by Y. El Housni and G. Botrel https://doi.org/10.46586/tches.v2023.i3.504-521 var t [5]uint64 var D uint64 diff --git a/ecc/stark-curve/fr/element_mul_amd64.s b/ecc/stark-curve/fr/element_mul_amd64.s deleted file mode 100644 index f773f8d0d..000000000 --- a/ecc/stark-curve/fr/element_mul_amd64.s +++ /dev/null @@ -1,490 +0,0 @@ -// +build !purego - -// Copyright 2020 ConsenSys Software Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
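The Mu constant retained in these assembly files (0x0000001fffffffff) is the single-word Barrett constant used by sumVec, namely the integer quotient of 2^288 by q. A minimal big.Int check against the stark-curve fr modulus whose words are listed just below, reading them least-significant first:

package main

import (
	"fmt"
	"math/big"
)

func main() {
	// stark-curve fr modulus, most-significant word first:
	// 0x0800000000000010 ffffffffffffffff b781126dcae7b232 1e66a241adc64d2f
	q, ok := new(big.Int).SetString(
		"0800000000000010ffffffffffffffffb781126dcae7b2321e66a241adc64d2f", 16)
	if !ok {
		panic("bad modulus literal")
	}
	mu := new(big.Int).Div(new(big.Int).Lsh(big.NewInt(1), 288), q)
	fmt.Printf("mu = %#x\n", mu) // mu = 0x1fffffffff
}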
- -#include "textflag.h" -#include "funcdata.h" - -// modulus q -DATA q<>+0(SB)/8, $0x1e66a241adc64d2f -DATA q<>+8(SB)/8, $0xb781126dcae7b232 -DATA q<>+16(SB)/8, $0xffffffffffffffff -DATA q<>+24(SB)/8, $0x0800000000000010 -GLOBL q<>(SB), (RODATA+NOPTR), $32 - -// qInv0 q'[0] -DATA qInv0<>(SB)/8, $0xbb6b3c4ce8bde631 -GLOBL qInv0<>(SB), (RODATA+NOPTR), $8 -// Mu -DATA mu<>(SB)/8, $0x0000001fffffffff -GLOBL mu<>(SB), (RODATA+NOPTR), $8 - -#define REDUCE(ra0, ra1, ra2, ra3, rb0, rb1, rb2, rb3) \ - MOVQ ra0, rb0; \ - SUBQ q<>(SB), ra0; \ - MOVQ ra1, rb1; \ - SBBQ q<>+8(SB), ra1; \ - MOVQ ra2, rb2; \ - SBBQ q<>+16(SB), ra2; \ - MOVQ ra3, rb3; \ - SBBQ q<>+24(SB), ra3; \ - CMOVQCS rb0, ra0; \ - CMOVQCS rb1, ra1; \ - CMOVQCS rb2, ra2; \ - CMOVQCS rb3, ra3; \ - -// mul(res, x, y *Element) -TEXT ·mul(SB), $24-24 - - // the algorithm is described in the Element.Mul declaration (.go) - // however, to benefit from the ADCX and ADOX carry chains - // we split the inner loops in 2: - // for i=0 to N-1 - // for j=0 to N-1 - // (A,t[j]) := t[j] + x[j]*y[i] + A - // m := t[0]*q'[0] mod W - // C,_ := t[0] + m*q[0] - // for j=1 to N-1 - // (C,t[j-1]) := t[j] + m*q[j] + C - // t[N-1] = C + A - - NO_LOCAL_POINTERS - CMPB ·supportAdx(SB), $1 - JNE noAdx_1 - MOVQ x+8(FP), SI - - // x[0] -> DI - // x[1] -> R8 - // x[2] -> R9 - // x[3] -> R10 - MOVQ 0(SI), DI - MOVQ 8(SI), R8 - MOVQ 16(SI), R9 - MOVQ 24(SI), R10 - MOVQ y+16(FP), R11 - - // A -> BP - // t[0] -> R14 - // t[1] -> R13 - // t[2] -> CX - // t[3] -> BX - // clear the flags - XORQ AX, AX - MOVQ 0(R11), DX - - // (A,t[0]) := x[0]*y[0] + A - MULXQ DI, R14, R13 - - // (A,t[1]) := x[1]*y[0] + A - MULXQ R8, AX, CX - ADOXQ AX, R13 - - // (A,t[2]) := x[2]*y[0] + A - MULXQ R9, AX, BX - ADOXQ AX, CX - - // (A,t[3]) := x[3]*y[0] + A - MULXQ R10, AX, BP - ADOXQ AX, BX - - // A += carries from ADCXQ and ADOXQ - MOVQ $0, AX - ADOXQ AX, BP - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ R14, DX - - // clear the flags - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ q<>+0(SB), AX, R12 - ADCXQ R14, AX - MOVQ R12, R14 - - // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R13, R14 - MULXQ q<>+8(SB), AX, R13 - ADOXQ AX, R14 - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ CX, R13 - MULXQ q<>+16(SB), AX, CX - ADOXQ AX, R13 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ BX, CX - MULXQ q<>+24(SB), AX, BX - ADOXQ AX, CX - - // t[3] = C + A - MOVQ $0, AX - ADCXQ AX, BX - ADOXQ BP, BX - - // clear the flags - XORQ AX, AX - MOVQ 8(R11), DX - - // (A,t[0]) := t[0] + x[0]*y[1] + A - MULXQ DI, AX, BP - ADOXQ AX, R14 - - // (A,t[1]) := t[1] + x[1]*y[1] + A - ADCXQ BP, R13 - MULXQ R8, AX, BP - ADOXQ AX, R13 - - // (A,t[2]) := t[2] + x[2]*y[1] + A - ADCXQ BP, CX - MULXQ R9, AX, BP - ADOXQ AX, CX - - // (A,t[3]) := t[3] + x[3]*y[1] + A - ADCXQ BP, BX - MULXQ R10, AX, BP - ADOXQ AX, BX - - // A += carries from ADCXQ and ADOXQ - MOVQ $0, AX - ADCXQ AX, BP - ADOXQ AX, BP - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ R14, DX - - // clear the flags - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ q<>+0(SB), AX, R12 - ADCXQ R14, AX - MOVQ R12, R14 - - // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R13, R14 - MULXQ q<>+8(SB), AX, R13 - ADOXQ AX, R14 - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ CX, R13 - MULXQ q<>+16(SB), AX, CX - ADOXQ AX, R13 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ BX, CX - MULXQ q<>+24(SB), AX, BX - ADOXQ AX, CX - - // t[3] = C + A - MOVQ $0, AX - ADCXQ AX, BX - ADOXQ BP, BX - - // clear the flags - XORQ AX, AX - MOVQ 16(R11), DX - - // (A,t[0]) := t[0] + x[0]*y[2] 
+ A - MULXQ DI, AX, BP - ADOXQ AX, R14 - - // (A,t[1]) := t[1] + x[1]*y[2] + A - ADCXQ BP, R13 - MULXQ R8, AX, BP - ADOXQ AX, R13 - - // (A,t[2]) := t[2] + x[2]*y[2] + A - ADCXQ BP, CX - MULXQ R9, AX, BP - ADOXQ AX, CX - - // (A,t[3]) := t[3] + x[3]*y[2] + A - ADCXQ BP, BX - MULXQ R10, AX, BP - ADOXQ AX, BX - - // A += carries from ADCXQ and ADOXQ - MOVQ $0, AX - ADCXQ AX, BP - ADOXQ AX, BP - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ R14, DX - - // clear the flags - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ q<>+0(SB), AX, R12 - ADCXQ R14, AX - MOVQ R12, R14 - - // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R13, R14 - MULXQ q<>+8(SB), AX, R13 - ADOXQ AX, R14 - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ CX, R13 - MULXQ q<>+16(SB), AX, CX - ADOXQ AX, R13 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ BX, CX - MULXQ q<>+24(SB), AX, BX - ADOXQ AX, CX - - // t[3] = C + A - MOVQ $0, AX - ADCXQ AX, BX - ADOXQ BP, BX - - // clear the flags - XORQ AX, AX - MOVQ 24(R11), DX - - // (A,t[0]) := t[0] + x[0]*y[3] + A - MULXQ DI, AX, BP - ADOXQ AX, R14 - - // (A,t[1]) := t[1] + x[1]*y[3] + A - ADCXQ BP, R13 - MULXQ R8, AX, BP - ADOXQ AX, R13 - - // (A,t[2]) := t[2] + x[2]*y[3] + A - ADCXQ BP, CX - MULXQ R9, AX, BP - ADOXQ AX, CX - - // (A,t[3]) := t[3] + x[3]*y[3] + A - ADCXQ BP, BX - MULXQ R10, AX, BP - ADOXQ AX, BX - - // A += carries from ADCXQ and ADOXQ - MOVQ $0, AX - ADCXQ AX, BP - ADOXQ AX, BP - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ R14, DX - - // clear the flags - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ q<>+0(SB), AX, R12 - ADCXQ R14, AX - MOVQ R12, R14 - - // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R13, R14 - MULXQ q<>+8(SB), AX, R13 - ADOXQ AX, R14 - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ CX, R13 - MULXQ q<>+16(SB), AX, CX - ADOXQ AX, R13 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ BX, CX - MULXQ q<>+24(SB), AX, BX - ADOXQ AX, CX - - // t[3] = C + A - MOVQ $0, AX - ADCXQ AX, BX - ADOXQ BP, BX - - // reduce element(R14,R13,CX,BX) using temp registers (SI,R12,R11,DI) - REDUCE(R14,R13,CX,BX,SI,R12,R11,DI) - - MOVQ res+0(FP), AX - MOVQ R14, 0(AX) - MOVQ R13, 8(AX) - MOVQ CX, 16(AX) - MOVQ BX, 24(AX) - RET - -noAdx_1: - MOVQ res+0(FP), AX - MOVQ AX, (SP) - MOVQ x+8(FP), AX - MOVQ AX, 8(SP) - MOVQ y+16(FP), AX - MOVQ AX, 16(SP) - CALL ·_mulGeneric(SB) - RET - -TEXT ·fromMont(SB), $8-8 - NO_LOCAL_POINTERS - - // the algorithm is described here - // https://hackmd.io/@gnark/modular_multiplication - // when y = 1 we have: - // for i=0 to N-1 - // t[i] = x[i] - // for i=0 to N-1 - // m := t[0]*q'[0] mod W - // C,_ := t[0] + m*q[0] - // for j=1 to N-1 - // (C,t[j-1]) := t[j] + m*q[j] + C - // t[N-1] = C - CMPB ·supportAdx(SB), $1 - JNE noAdx_2 - MOVQ res+0(FP), DX - MOVQ 0(DX), R14 - MOVQ 8(DX), R13 - MOVQ 16(DX), CX - MOVQ 24(DX), BX - XORQ DX, DX - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ R14, DX - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ q<>+0(SB), AX, BP - ADCXQ R14, AX - MOVQ BP, R14 - - // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R13, R14 - MULXQ q<>+8(SB), AX, R13 - ADOXQ AX, R14 - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ CX, R13 - MULXQ q<>+16(SB), AX, CX - ADOXQ AX, R13 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ BX, CX - MULXQ q<>+24(SB), AX, BX - ADOXQ AX, CX - MOVQ $0, AX - ADCXQ AX, BX - ADOXQ AX, BX - XORQ DX, DX - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ R14, DX - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ q<>+0(SB), AX, BP - ADCXQ R14, AX - MOVQ BP, R14 - - // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ 
R13, R14 - MULXQ q<>+8(SB), AX, R13 - ADOXQ AX, R14 - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ CX, R13 - MULXQ q<>+16(SB), AX, CX - ADOXQ AX, R13 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ BX, CX - MULXQ q<>+24(SB), AX, BX - ADOXQ AX, CX - MOVQ $0, AX - ADCXQ AX, BX - ADOXQ AX, BX - XORQ DX, DX - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ R14, DX - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ q<>+0(SB), AX, BP - ADCXQ R14, AX - MOVQ BP, R14 - - // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R13, R14 - MULXQ q<>+8(SB), AX, R13 - ADOXQ AX, R14 - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ CX, R13 - MULXQ q<>+16(SB), AX, CX - ADOXQ AX, R13 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ BX, CX - MULXQ q<>+24(SB), AX, BX - ADOXQ AX, CX - MOVQ $0, AX - ADCXQ AX, BX - ADOXQ AX, BX - XORQ DX, DX - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ R14, DX - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ q<>+0(SB), AX, BP - ADCXQ R14, AX - MOVQ BP, R14 - - // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R13, R14 - MULXQ q<>+8(SB), AX, R13 - ADOXQ AX, R14 - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ CX, R13 - MULXQ q<>+16(SB), AX, CX - ADOXQ AX, R13 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ BX, CX - MULXQ q<>+24(SB), AX, BX - ADOXQ AX, CX - MOVQ $0, AX - ADCXQ AX, BX - ADOXQ AX, BX - - // reduce element(R14,R13,CX,BX) using temp registers (SI,DI,R8,R9) - REDUCE(R14,R13,CX,BX,SI,DI,R8,R9) - - MOVQ res+0(FP), AX - MOVQ R14, 0(AX) - MOVQ R13, 8(AX) - MOVQ CX, 16(AX) - MOVQ BX, 24(AX) - RET - -noAdx_2: - MOVQ res+0(FP), AX - MOVQ AX, (SP) - CALL ·_fromMontGeneric(SB) - RET diff --git a/ecc/stark-curve/fr/element_ops_amd64.go b/ecc/stark-curve/fr/element_ops_amd64.go index 49da23450..d21dabdaa 100644 --- a/ecc/stark-curve/fr/element_ops_amd64.go +++ b/ecc/stark-curve/fr/element_ops_amd64.go @@ -103,48 +103,8 @@ func sumVec(res *Element, a *Element, n uint64) // x and y must be less than q func (z *Element) Mul(x, y *Element) *Element { - // Implements CIOS multiplication -- section 2.3.2 of Tolga Acar's thesis - // https://www.microsoft.com/en-us/research/wp-content/uploads/1998/06/97Acar.pdf - // - // The algorithm: - // - // for i=0 to N-1 - // C := 0 - // for j=0 to N-1 - // (C,t[j]) := t[j] + x[j]*y[i] + C - // (t[N+1],t[N]) := t[N] + C - // - // C := 0 - // m := t[0]*q'[0] mod D - // (C,_) := t[0] + m*q[0] - // for j=1 to N-1 - // (C,t[j-1]) := t[j] + m*q[j] + C - // - // (C,t[N-1]) := t[N] + C - // t[N] := t[N+1] + C - // - // → N is the number of machine words needed to store the modulus q - // → D is the word size. For example, on a 64-bit architecture D is 2 64 - // → x[i], y[i], q[i] is the ith word of the numbers x,y,q - // → q'[0] is the lowest word of the number -q⁻¹ mod r. This quantity is pre-computed, as it does not depend on the inputs. - // → t is a temporary array of size N+2 - // → C, S are machine words. A pair (C,S) refers to (hi-bits, lo-bits) of a two-word number - // - // As described here https://hackmd.io/@gnark/modular_multiplication we can get rid of one carry chain and simplify: - // (also described in https://eprint.iacr.org/2022/1400.pdf annex) - // - // for i=0 to N-1 - // (A,t[0]) := t[0] + x[0]*y[i] - // m := t[0]*q'[0] mod W - // C,_ := t[0] + m*q[0] - // for j=1 to N-1 - // (A,t[j]) := t[j] + x[j]*y[i] + A - // (C,t[j-1]) := t[j] + m*q[j] + C - // - // t[N-1] = C + A - // - // This optimization saves 5N + 2 additions in the algorithm, and can be used whenever the highest bit - // of the modulus is zero (and not all of the remaining bits are set). 
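The no-carry optimization described in the comment deleted above only applies when the most significant bit of the modulus is zero; both stark-curve moduli in this patch satisfy that condition, as a trivial sketch over the top q<> words shows.

package main

import "fmt"

func main() {
	const fpQ3 = uint64(0x0800000000000011) // top word of the stark-curve fp modulus
	const frQ3 = uint64(0x0800000000000010) // top word of the stark-curve fr modulus
	fmt.Println(fpQ3>>63 == 0, frQ3>>63 == 0) // true true
}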
+ // Algorithm 2 of "Faster Montgomery Multiplication and Multi-Scalar-Multiplication for SNARKS" + // by Y. El Housni and G. Botrel https://doi.org/10.46586/tches.v2023.i3.504-521 mul(z, x, y) return z diff --git a/ecc/stark-curve/fr/element_ops_amd64.s b/ecc/stark-curve/fr/element_ops_amd64.s index e0149bd6c..e2c1858e8 100644 --- a/ecc/stark-curve/fr/element_ops_amd64.s +++ b/ecc/stark-curve/fr/element_ops_amd64.s @@ -1,21 +1,11 @@ // +build !purego -// Copyright 2020 ConsenSys Software Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. +#define q0 $0x1e66a241adc64d2f +#define q1 $0xb781126dcae7b232 +#define q2 $0xffffffffffffffff +#define q3 $0x0800000000000010 -#include "textflag.h" -#include "funcdata.h" +#include "../../../field/asm/element_4w_amd64.h" // modulus q DATA q<>+0(SB)/8, $0x1e66a241adc64d2f @@ -31,809 +21,3 @@ GLOBL qInv0<>(SB), (RODATA+NOPTR), $8 DATA mu<>(SB)/8, $0x0000001fffffffff GLOBL mu<>(SB), (RODATA+NOPTR), $8 -#define REDUCE(ra0, ra1, ra2, ra3, rb0, rb1, rb2, rb3) \ - MOVQ ra0, rb0; \ - SUBQ q<>(SB), ra0; \ - MOVQ ra1, rb1; \ - SBBQ q<>+8(SB), ra1; \ - MOVQ ra2, rb2; \ - SBBQ q<>+16(SB), ra2; \ - MOVQ ra3, rb3; \ - SBBQ q<>+24(SB), ra3; \ - CMOVQCS rb0, ra0; \ - CMOVQCS rb1, ra1; \ - CMOVQCS rb2, ra2; \ - CMOVQCS rb3, ra3; \ - -TEXT ·reduce(SB), NOSPLIT, $0-8 - MOVQ res+0(FP), AX - MOVQ 0(AX), DX - MOVQ 8(AX), CX - MOVQ 16(AX), BX - MOVQ 24(AX), SI - - // reduce element(DX,CX,BX,SI) using temp registers (DI,R8,R9,R10) - REDUCE(DX,CX,BX,SI,DI,R8,R9,R10) - - MOVQ DX, 0(AX) - MOVQ CX, 8(AX) - MOVQ BX, 16(AX) - MOVQ SI, 24(AX) - RET - -// MulBy3(x *Element) -TEXT ·MulBy3(SB), NOSPLIT, $0-8 - MOVQ x+0(FP), AX - MOVQ 0(AX), DX - MOVQ 8(AX), CX - MOVQ 16(AX), BX - MOVQ 24(AX), SI - ADDQ DX, DX - ADCQ CX, CX - ADCQ BX, BX - ADCQ SI, SI - - // reduce element(DX,CX,BX,SI) using temp registers (DI,R8,R9,R10) - REDUCE(DX,CX,BX,SI,DI,R8,R9,R10) - - ADDQ 0(AX), DX - ADCQ 8(AX), CX - ADCQ 16(AX), BX - ADCQ 24(AX), SI - - // reduce element(DX,CX,BX,SI) using temp registers (R11,R12,R13,R14) - REDUCE(DX,CX,BX,SI,R11,R12,R13,R14) - - MOVQ DX, 0(AX) - MOVQ CX, 8(AX) - MOVQ BX, 16(AX) - MOVQ SI, 24(AX) - RET - -// MulBy5(x *Element) -TEXT ·MulBy5(SB), NOSPLIT, $0-8 - MOVQ x+0(FP), AX - MOVQ 0(AX), DX - MOVQ 8(AX), CX - MOVQ 16(AX), BX - MOVQ 24(AX), SI - ADDQ DX, DX - ADCQ CX, CX - ADCQ BX, BX - ADCQ SI, SI - - // reduce element(DX,CX,BX,SI) using temp registers (DI,R8,R9,R10) - REDUCE(DX,CX,BX,SI,DI,R8,R9,R10) - - ADDQ DX, DX - ADCQ CX, CX - ADCQ BX, BX - ADCQ SI, SI - - // reduce element(DX,CX,BX,SI) using temp registers (R11,R12,R13,R14) - REDUCE(DX,CX,BX,SI,R11,R12,R13,R14) - - ADDQ 0(AX), DX - ADCQ 8(AX), CX - ADCQ 16(AX), BX - ADCQ 24(AX), SI - - // reduce element(DX,CX,BX,SI) using temp registers (R15,DI,R8,R9) - REDUCE(DX,CX,BX,SI,R15,DI,R8,R9) - - MOVQ DX, 0(AX) - MOVQ CX, 8(AX) - MOVQ BX, 16(AX) - MOVQ SI, 24(AX) - RET - -// MulBy13(x *Element) -TEXT ·MulBy13(SB), NOSPLIT, $0-8 - MOVQ x+0(FP), AX - MOVQ 0(AX), DX - MOVQ 8(AX), CX - MOVQ 16(AX), BX - MOVQ 24(AX), 
SI - ADDQ DX, DX - ADCQ CX, CX - ADCQ BX, BX - ADCQ SI, SI - - // reduce element(DX,CX,BX,SI) using temp registers (DI,R8,R9,R10) - REDUCE(DX,CX,BX,SI,DI,R8,R9,R10) - - ADDQ DX, DX - ADCQ CX, CX - ADCQ BX, BX - ADCQ SI, SI - - // reduce element(DX,CX,BX,SI) using temp registers (R11,R12,R13,R14) - REDUCE(DX,CX,BX,SI,R11,R12,R13,R14) - - MOVQ DX, R11 - MOVQ CX, R12 - MOVQ BX, R13 - MOVQ SI, R14 - ADDQ DX, DX - ADCQ CX, CX - ADCQ BX, BX - ADCQ SI, SI - - // reduce element(DX,CX,BX,SI) using temp registers (DI,R8,R9,R10) - REDUCE(DX,CX,BX,SI,DI,R8,R9,R10) - - ADDQ R11, DX - ADCQ R12, CX - ADCQ R13, BX - ADCQ R14, SI - - // reduce element(DX,CX,BX,SI) using temp registers (DI,R8,R9,R10) - REDUCE(DX,CX,BX,SI,DI,R8,R9,R10) - - ADDQ 0(AX), DX - ADCQ 8(AX), CX - ADCQ 16(AX), BX - ADCQ 24(AX), SI - - // reduce element(DX,CX,BX,SI) using temp registers (DI,R8,R9,R10) - REDUCE(DX,CX,BX,SI,DI,R8,R9,R10) - - MOVQ DX, 0(AX) - MOVQ CX, 8(AX) - MOVQ BX, 16(AX) - MOVQ SI, 24(AX) - RET - -// Butterfly(a, b *Element) sets a = a + b; b = a - b -TEXT ·Butterfly(SB), NOSPLIT, $0-16 - MOVQ a+0(FP), AX - MOVQ 0(AX), CX - MOVQ 8(AX), BX - MOVQ 16(AX), SI - MOVQ 24(AX), DI - MOVQ CX, R8 - MOVQ BX, R9 - MOVQ SI, R10 - MOVQ DI, R11 - XORQ AX, AX - MOVQ b+8(FP), DX - ADDQ 0(DX), CX - ADCQ 8(DX), BX - ADCQ 16(DX), SI - ADCQ 24(DX), DI - SUBQ 0(DX), R8 - SBBQ 8(DX), R9 - SBBQ 16(DX), R10 - SBBQ 24(DX), R11 - MOVQ $0x1e66a241adc64d2f, R12 - MOVQ $0xb781126dcae7b232, R13 - MOVQ $0xffffffffffffffff, R14 - MOVQ $0x0800000000000010, R15 - CMOVQCC AX, R12 - CMOVQCC AX, R13 - CMOVQCC AX, R14 - CMOVQCC AX, R15 - ADDQ R12, R8 - ADCQ R13, R9 - ADCQ R14, R10 - ADCQ R15, R11 - MOVQ R8, 0(DX) - MOVQ R9, 8(DX) - MOVQ R10, 16(DX) - MOVQ R11, 24(DX) - - // reduce element(CX,BX,SI,DI) using temp registers (R8,R9,R10,R11) - REDUCE(CX,BX,SI,DI,R8,R9,R10,R11) - - MOVQ a+0(FP), AX - MOVQ CX, 0(AX) - MOVQ BX, 8(AX) - MOVQ SI, 16(AX) - MOVQ DI, 24(AX) - RET - -// addVec(res, a, b *Element, n uint64) res[0...n] = a[0...n] + b[0...n] -TEXT ·addVec(SB), NOSPLIT, $0-32 - MOVQ res+0(FP), CX - MOVQ a+8(FP), AX - MOVQ b+16(FP), DX - MOVQ n+24(FP), BX - -loop_1: - TESTQ BX, BX - JEQ done_2 // n == 0, we are done - - // a[0] -> SI - // a[1] -> DI - // a[2] -> R8 - // a[3] -> R9 - MOVQ 0(AX), SI - MOVQ 8(AX), DI - MOVQ 16(AX), R8 - MOVQ 24(AX), R9 - ADDQ 0(DX), SI - ADCQ 8(DX), DI - ADCQ 16(DX), R8 - ADCQ 24(DX), R9 - - // reduce element(SI,DI,R8,R9) using temp registers (R10,R11,R12,R13) - REDUCE(SI,DI,R8,R9,R10,R11,R12,R13) - - MOVQ SI, 0(CX) - MOVQ DI, 8(CX) - MOVQ R8, 16(CX) - MOVQ R9, 24(CX) - - // increment pointers to visit next element - ADDQ $32, AX - ADDQ $32, DX - ADDQ $32, CX - DECQ BX // decrement n - JMP loop_1 - -done_2: - RET - -// subVec(res, a, b *Element, n uint64) res[0...n] = a[0...n] - b[0...n] -TEXT ·subVec(SB), NOSPLIT, $0-32 - MOVQ res+0(FP), CX - MOVQ a+8(FP), AX - MOVQ b+16(FP), DX - MOVQ n+24(FP), BX - XORQ SI, SI - -loop_3: - TESTQ BX, BX - JEQ done_4 // n == 0, we are done - - // a[0] -> DI - // a[1] -> R8 - // a[2] -> R9 - // a[3] -> R10 - MOVQ 0(AX), DI - MOVQ 8(AX), R8 - MOVQ 16(AX), R9 - MOVQ 24(AX), R10 - SUBQ 0(DX), DI - SBBQ 8(DX), R8 - SBBQ 16(DX), R9 - SBBQ 24(DX), R10 - - // reduce (a-b) mod q - // q[0] -> R11 - // q[1] -> R12 - // q[2] -> R13 - // q[3] -> R14 - MOVQ $0x1e66a241adc64d2f, R11 - MOVQ $0xb781126dcae7b232, R12 - MOVQ $0xffffffffffffffff, R13 - MOVQ $0x0800000000000010, R14 - CMOVQCC SI, R11 - CMOVQCC SI, R12 - CMOVQCC SI, R13 - CMOVQCC SI, R14 - - // add registers (q or 0) to a, and set to result - 
ADDQ R11, DI - ADCQ R12, R8 - ADCQ R13, R9 - ADCQ R14, R10 - MOVQ DI, 0(CX) - MOVQ R8, 8(CX) - MOVQ R9, 16(CX) - MOVQ R10, 24(CX) - - // increment pointers to visit next element - ADDQ $32, AX - ADDQ $32, DX - ADDQ $32, CX - DECQ BX // decrement n - JMP loop_3 - -done_4: - RET - -// scalarMulVec(res, a, b *Element, n uint64) res[0...n] = a[0...n] * b -TEXT ·scalarMulVec(SB), $56-32 - CMPB ·supportAdx(SB), $1 - JNE noAdx_5 - MOVQ a+8(FP), R11 - MOVQ b+16(FP), R10 - MOVQ n+24(FP), R12 - - // scalar[0] -> SI - // scalar[1] -> DI - // scalar[2] -> R8 - // scalar[3] -> R9 - MOVQ 0(R10), SI - MOVQ 8(R10), DI - MOVQ 16(R10), R8 - MOVQ 24(R10), R9 - MOVQ res+0(FP), R10 - -loop_6: - TESTQ R12, R12 - JEQ done_7 // n == 0, we are done - - // TODO @gbotrel this is generated from the same macro as the unit mul, we should refactor this in a single asm function - // A -> BP - // t[0] -> R14 - // t[1] -> R15 - // t[2] -> CX - // t[3] -> BX - // clear the flags - XORQ AX, AX - MOVQ 0(R11), DX - - // (A,t[0]) := x[0]*y[0] + A - MULXQ SI, R14, R15 - - // (A,t[1]) := x[1]*y[0] + A - MULXQ DI, AX, CX - ADOXQ AX, R15 - - // (A,t[2]) := x[2]*y[0] + A - MULXQ R8, AX, BX - ADOXQ AX, CX - - // (A,t[3]) := x[3]*y[0] + A - MULXQ R9, AX, BP - ADOXQ AX, BX - - // A += carries from ADCXQ and ADOXQ - MOVQ $0, AX - ADOXQ AX, BP - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ R14, DX - - // clear the flags - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ q<>+0(SB), AX, R13 - ADCXQ R14, AX - MOVQ R13, R14 - - // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R15, R14 - MULXQ q<>+8(SB), AX, R15 - ADOXQ AX, R14 - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ CX, R15 - MULXQ q<>+16(SB), AX, CX - ADOXQ AX, R15 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ BX, CX - MULXQ q<>+24(SB), AX, BX - ADOXQ AX, CX - - // t[3] = C + A - MOVQ $0, AX - ADCXQ AX, BX - ADOXQ BP, BX - - // clear the flags - XORQ AX, AX - MOVQ 8(R11), DX - - // (A,t[0]) := t[0] + x[0]*y[1] + A - MULXQ SI, AX, BP - ADOXQ AX, R14 - - // (A,t[1]) := t[1] + x[1]*y[1] + A - ADCXQ BP, R15 - MULXQ DI, AX, BP - ADOXQ AX, R15 - - // (A,t[2]) := t[2] + x[2]*y[1] + A - ADCXQ BP, CX - MULXQ R8, AX, BP - ADOXQ AX, CX - - // (A,t[3]) := t[3] + x[3]*y[1] + A - ADCXQ BP, BX - MULXQ R9, AX, BP - ADOXQ AX, BX - - // A += carries from ADCXQ and ADOXQ - MOVQ $0, AX - ADCXQ AX, BP - ADOXQ AX, BP - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ R14, DX - - // clear the flags - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ q<>+0(SB), AX, R13 - ADCXQ R14, AX - MOVQ R13, R14 - - // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R15, R14 - MULXQ q<>+8(SB), AX, R15 - ADOXQ AX, R14 - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ CX, R15 - MULXQ q<>+16(SB), AX, CX - ADOXQ AX, R15 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ BX, CX - MULXQ q<>+24(SB), AX, BX - ADOXQ AX, CX - - // t[3] = C + A - MOVQ $0, AX - ADCXQ AX, BX - ADOXQ BP, BX - - // clear the flags - XORQ AX, AX - MOVQ 16(R11), DX - - // (A,t[0]) := t[0] + x[0]*y[2] + A - MULXQ SI, AX, BP - ADOXQ AX, R14 - - // (A,t[1]) := t[1] + x[1]*y[2] + A - ADCXQ BP, R15 - MULXQ DI, AX, BP - ADOXQ AX, R15 - - // (A,t[2]) := t[2] + x[2]*y[2] + A - ADCXQ BP, CX - MULXQ R8, AX, BP - ADOXQ AX, CX - - // (A,t[3]) := t[3] + x[3]*y[2] + A - ADCXQ BP, BX - MULXQ R9, AX, BP - ADOXQ AX, BX - - // A += carries from ADCXQ and ADOXQ - MOVQ $0, AX - ADCXQ AX, BP - ADOXQ AX, BP - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ R14, DX - - // clear the flags - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ q<>+0(SB), AX, R13 - ADCXQ R14, AX - MOVQ 
R13, R14 - - // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R15, R14 - MULXQ q<>+8(SB), AX, R15 - ADOXQ AX, R14 - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ CX, R15 - MULXQ q<>+16(SB), AX, CX - ADOXQ AX, R15 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ BX, CX - MULXQ q<>+24(SB), AX, BX - ADOXQ AX, CX - - // t[3] = C + A - MOVQ $0, AX - ADCXQ AX, BX - ADOXQ BP, BX - - // clear the flags - XORQ AX, AX - MOVQ 24(R11), DX - - // (A,t[0]) := t[0] + x[0]*y[3] + A - MULXQ SI, AX, BP - ADOXQ AX, R14 - - // (A,t[1]) := t[1] + x[1]*y[3] + A - ADCXQ BP, R15 - MULXQ DI, AX, BP - ADOXQ AX, R15 - - // (A,t[2]) := t[2] + x[2]*y[3] + A - ADCXQ BP, CX - MULXQ R8, AX, BP - ADOXQ AX, CX - - // (A,t[3]) := t[3] + x[3]*y[3] + A - ADCXQ BP, BX - MULXQ R9, AX, BP - ADOXQ AX, BX - - // A += carries from ADCXQ and ADOXQ - MOVQ $0, AX - ADCXQ AX, BP - ADOXQ AX, BP - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ R14, DX - - // clear the flags - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ q<>+0(SB), AX, R13 - ADCXQ R14, AX - MOVQ R13, R14 - - // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R15, R14 - MULXQ q<>+8(SB), AX, R15 - ADOXQ AX, R14 - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ CX, R15 - MULXQ q<>+16(SB), AX, CX - ADOXQ AX, R15 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ BX, CX - MULXQ q<>+24(SB), AX, BX - ADOXQ AX, CX - - // t[3] = C + A - MOVQ $0, AX - ADCXQ AX, BX - ADOXQ BP, BX - - // reduce t mod q - // reduce element(R14,R15,CX,BX) using temp registers (R13,AX,DX,s0-8(SP)) - REDUCE(R14,R15,CX,BX,R13,AX,DX,s0-8(SP)) - - MOVQ R14, 0(R10) - MOVQ R15, 8(R10) - MOVQ CX, 16(R10) - MOVQ BX, 24(R10) - - // increment pointers to visit next element - ADDQ $32, R11 - ADDQ $32, R10 - DECQ R12 // decrement n - JMP loop_6 - -done_7: - RET - -noAdx_5: - MOVQ n+24(FP), DX - MOVQ res+0(FP), AX - MOVQ AX, (SP) - MOVQ DX, 8(SP) - MOVQ DX, 16(SP) - MOVQ a+8(FP), AX - MOVQ AX, 24(SP) - MOVQ DX, 32(SP) - MOVQ DX, 40(SP) - MOVQ b+16(FP), AX - MOVQ AX, 48(SP) - CALL ·scalarMulVecGeneric(SB) - RET - -// sumVec(res, a *Element, n uint64) res = sum(a[0...n]) -TEXT ·sumVec(SB), NOSPLIT, $0-24 - - // Derived from https://github.com/a16z/vectorized-fields - // The idea is to use Z registers to accumulate the sum of elements, 8 by 8 - // first, we handle the case where n % 8 != 0 - // then, we loop over the elements 8 by 8 and accumulate the sum in the Z registers - // finally, we reduce the sum and store it in res - // - // when we move an element of a into a Z register, we use VPMOVZXDQ - // let's note w0...w3 the 4 64bits words of ai: w0 = ai[0], w1 = ai[1], w2 = ai[2], w3 = ai[3] - // VPMOVZXDQ(ai, Z0) will result in - // Z0= [hi(w3), lo(w3), hi(w2), lo(w2), hi(w1), lo(w1), hi(w0), lo(w0)] - // with hi(wi) the high 32 bits of wi and lo(wi) the low 32 bits of wi - // we can safely add 2^32+1 times Z registers constructed this way without overflow - // since each of this lo/hi bits are moved into a "64bits" slot - // N = 2^64-1 / 2^32-1 = 2^32+1 - // - // we then propagate the carry using ADOXQ and ADCXQ - // r0 = w0l + lo(woh) - // r1 = carry + hi(woh) + w1l + lo(w1h) - // r2 = carry + hi(w1h) + w2l + lo(w2h) - // r3 = carry + hi(w2h) + w3l + lo(w3h) - // r4 = carry + hi(w3h) - // we then reduce the sum using a single-word Barrett reduction - // we pick mu = 2^288 / q; which correspond to 4.5 words max. - // meaning we must guarantee that r4 fits in 32bits. 
- // To do so, we reduce N to 2^32-1 (since r4 receives 2 carries max) - - MOVQ a+8(FP), R14 - MOVQ n+16(FP), R15 - - // initialize accumulators Z0, Z1, Z2, Z3, Z4, Z5, Z6, Z7 - VXORPS Z0, Z0, Z0 - VMOVDQA64 Z0, Z1 - VMOVDQA64 Z0, Z2 - VMOVDQA64 Z0, Z3 - VMOVDQA64 Z0, Z4 - VMOVDQA64 Z0, Z5 - VMOVDQA64 Z0, Z6 - VMOVDQA64 Z0, Z7 - - // n % 8 -> CX - // n / 8 -> R15 - MOVQ R15, CX - ANDQ $7, CX - SHRQ $3, R15 - -loop_single_10: - TESTQ CX, CX - JEQ loop8by8_8 // n % 8 == 0, we are going to loop over 8 by 8 - VPMOVZXDQ 0(R14), Z8 - VPADDQ Z8, Z0, Z0 - ADDQ $32, R14 - DECQ CX // decrement nMod8 - JMP loop_single_10 - -loop8by8_8: - TESTQ R15, R15 - JEQ accumulate_11 // n == 0, we are going to accumulate - VPMOVZXDQ 0*32(R14), Z8 - VPMOVZXDQ 1*32(R14), Z9 - VPMOVZXDQ 2*32(R14), Z10 - VPMOVZXDQ 3*32(R14), Z11 - VPMOVZXDQ 4*32(R14), Z12 - VPMOVZXDQ 5*32(R14), Z13 - VPMOVZXDQ 6*32(R14), Z14 - VPMOVZXDQ 7*32(R14), Z15 - PREFETCHT0 256(R14) - VPADDQ Z8, Z0, Z0 - VPADDQ Z9, Z1, Z1 - VPADDQ Z10, Z2, Z2 - VPADDQ Z11, Z3, Z3 - VPADDQ Z12, Z4, Z4 - VPADDQ Z13, Z5, Z5 - VPADDQ Z14, Z6, Z6 - VPADDQ Z15, Z7, Z7 - - // increment pointers to visit next 8 elements - ADDQ $256, R14 - DECQ R15 // decrement n - JMP loop8by8_8 - -accumulate_11: - // accumulate the 8 Z registers into Z0 - VPADDQ Z7, Z6, Z6 - VPADDQ Z6, Z5, Z5 - VPADDQ Z5, Z4, Z4 - VPADDQ Z4, Z3, Z3 - VPADDQ Z3, Z2, Z2 - VPADDQ Z2, Z1, Z1 - VPADDQ Z1, Z0, Z0 - - // carry propagation - // lo(w0) -> BX - // hi(w0) -> SI - // lo(w1) -> DI - // hi(w1) -> R8 - // lo(w2) -> R9 - // hi(w2) -> R10 - // lo(w3) -> R11 - // hi(w3) -> R12 - VMOVQ X0, BX - VALIGNQ $1, Z0, Z0, Z0 - VMOVQ X0, SI - VALIGNQ $1, Z0, Z0, Z0 - VMOVQ X0, DI - VALIGNQ $1, Z0, Z0, Z0 - VMOVQ X0, R8 - VALIGNQ $1, Z0, Z0, Z0 - VMOVQ X0, R9 - VALIGNQ $1, Z0, Z0, Z0 - VMOVQ X0, R10 - VALIGNQ $1, Z0, Z0, Z0 - VMOVQ X0, R11 - VALIGNQ $1, Z0, Z0, Z0 - VMOVQ X0, R12 - - // lo(hi(wo)) -> R13 - // lo(hi(w1)) -> CX - // lo(hi(w2)) -> R15 - // lo(hi(w3)) -> R14 -#define SPLIT_LO_HI(lo, hi) \ - MOVQ hi, lo; \ - ANDQ $0xffffffff, lo; \ - SHLQ $32, lo; \ - SHRQ $32, hi; \ - - SPLIT_LO_HI(R13, SI) - SPLIT_LO_HI(CX, R8) - SPLIT_LO_HI(R15, R10) - SPLIT_LO_HI(R14, R12) - - // r0 = w0l + lo(woh) - // r1 = carry + hi(woh) + w1l + lo(w1h) - // r2 = carry + hi(w1h) + w2l + lo(w2h) - // r3 = carry + hi(w2h) + w3l + lo(w3h) - // r4 = carry + hi(w3h) - - XORQ AX, AX // clear the flags - ADOXQ R13, BX - ADOXQ CX, DI - ADCXQ SI, DI - ADOXQ R15, R9 - ADCXQ R8, R9 - ADOXQ R14, R11 - ADCXQ R10, R11 - ADOXQ AX, R12 - ADCXQ AX, R12 - - // r[0] -> BX - // r[1] -> DI - // r[2] -> R9 - // r[3] -> R11 - // r[4] -> R12 - // reduce using single-word Barrett - // mu=2^288 / q -> SI - MOVQ mu<>(SB), SI - MOVQ R11, AX - SHRQ $32, R12, AX - MULQ SI // high bits of res stored in DX - MULXQ q<>+0(SB), AX, SI - SUBQ AX, BX - SBBQ SI, DI - MULXQ q<>+16(SB), AX, SI - SBBQ AX, R9 - SBBQ SI, R11 - SBBQ $0, R12 - MULXQ q<>+8(SB), AX, SI - SUBQ AX, DI - SBBQ SI, R9 - MULXQ q<>+24(SB), AX, SI - SBBQ AX, R11 - SBBQ SI, R12 - MOVQ BX, R8 - MOVQ DI, R10 - MOVQ R9, R13 - MOVQ R11, CX - SUBQ q<>+0(SB), BX - SBBQ q<>+8(SB), DI - SBBQ q<>+16(SB), R9 - SBBQ q<>+24(SB), R11 - SBBQ $0, R12 - JCS modReduced_12 - MOVQ BX, R8 - MOVQ DI, R10 - MOVQ R9, R13 - MOVQ R11, CX - SUBQ q<>+0(SB), BX - SBBQ q<>+8(SB), DI - SBBQ q<>+16(SB), R9 - SBBQ q<>+24(SB), R11 - SBBQ $0, R12 - JCS modReduced_12 - MOVQ BX, R8 - MOVQ DI, R10 - MOVQ R9, R13 - MOVQ R11, CX - -modReduced_12: - MOVQ res+0(FP), SI - MOVQ R8, 0(SI) - MOVQ R10, 8(SI) - MOVQ R13, 16(SI) - MOVQ CX, 
24(SI) - -done_9: - RET diff --git a/ecc/stark-curve/fr/element_ops_purego.go b/ecc/stark-curve/fr/element_ops_purego.go index a45314560..7a2cde14b 100644 --- a/ecc/stark-curve/fr/element_ops_purego.go +++ b/ecc/stark-curve/fr/element_ops_purego.go @@ -89,48 +89,8 @@ func (vector *Vector) Sum() (res Element) { // x and y must be less than q func (z *Element) Mul(x, y *Element) *Element { - // Implements CIOS multiplication -- section 2.3.2 of Tolga Acar's thesis - // https://www.microsoft.com/en-us/research/wp-content/uploads/1998/06/97Acar.pdf - // - // The algorithm: - // - // for i=0 to N-1 - // C := 0 - // for j=0 to N-1 - // (C,t[j]) := t[j] + x[j]*y[i] + C - // (t[N+1],t[N]) := t[N] + C - // - // C := 0 - // m := t[0]*q'[0] mod D - // (C,_) := t[0] + m*q[0] - // for j=1 to N-1 - // (C,t[j-1]) := t[j] + m*q[j] + C - // - // (C,t[N-1]) := t[N] + C - // t[N] := t[N+1] + C - // - // → N is the number of machine words needed to store the modulus q - // → D is the word size. For example, on a 64-bit architecture D is 2 64 - // → x[i], y[i], q[i] is the ith word of the numbers x,y,q - // → q'[0] is the lowest word of the number -q⁻¹ mod r. This quantity is pre-computed, as it does not depend on the inputs. - // → t is a temporary array of size N+2 - // → C, S are machine words. A pair (C,S) refers to (hi-bits, lo-bits) of a two-word number - // - // As described here https://hackmd.io/@gnark/modular_multiplication we can get rid of one carry chain and simplify: - // (also described in https://eprint.iacr.org/2022/1400.pdf annex) - // - // for i=0 to N-1 - // (A,t[0]) := t[0] + x[0]*y[i] - // m := t[0]*q'[0] mod W - // C,_ := t[0] + m*q[0] - // for j=1 to N-1 - // (A,t[j]) := t[j] + x[j]*y[i] + A - // (C,t[j-1]) := t[j] + m*q[j] + C - // - // t[N-1] = C + A - // - // This optimization saves 5N + 2 additions in the algorithm, and can be used whenever the highest bit - // of the modulus is zero (and not all of the remaining bits are set). + // Algorithm 2 of "Faster Montgomery Multiplication and Multi-Scalar-Multiplication for SNARKS" + // by Y. El Housni and G. 
Botrel https://doi.org/10.46586/tches.v2023.i3.504-521 var t0, t1, t2, t3 uint64 var u0, u1, u2, u3 uint64 diff --git a/ecc/stark-curve/fr/element_test.go b/ecc/stark-curve/fr/element_test.go index 88dc47337..d71f25273 100644 --- a/ecc/stark-curve/fr/element_test.go +++ b/ecc/stark-curve/fr/element_test.go @@ -637,7 +637,6 @@ func TestElementBitLen(t *testing.T) { )) properties.TestingRun(t, gopter.ConsoleReporter(false)) - } func TestElementButterflies(t *testing.T) { @@ -719,11 +718,17 @@ func TestElementVecOps(t *testing.T) { // set m to max values element // it's not really q-1 (since we have montgomery representation) // but it's the "largest" legal value - qMinus1 := new(big.Int).Sub(Modulus(), big.NewInt(1)) - - var eQMinus1 Element - for i, v := range qMinus1.Bits() { - eQMinus1[i] = uint64(v) + eQMinus1 := qElement + if eQMinus1[0] != 0 { + eQMinus1[0]-- + } else { + eQMinus1[0] = ^uint64(0) + for i := 1; i < len(eQMinus1); i++ { + if eQMinus1[i] != 0 { + eQMinus1[i]-- + break + } + } } for i := 0; i < N; i++ { diff --git a/field/asm/.gitignore b/field/asm/.gitignore new file mode 100644 index 000000000..eb82b555a --- /dev/null +++ b/field/asm/.gitignore @@ -0,0 +1,5 @@ +# generated by integration tests +element_2w_amd64.h +element_3w_amd64.h +element_7w_amd64.h +element_8w_amd64.h \ No newline at end of file diff --git a/ecc/bw6-633/fp/element_mul_amd64.s b/field/asm/element_10w_amd64.h similarity index 79% rename from ecc/bw6-633/fp/element_mul_amd64.s rename to field/asm/element_10w_amd64.h index 62a7d4dda..69fd79c80 100644 --- a/ecc/bw6-633/fp/element_mul_amd64.s +++ b/field/asm/element_10w_amd64.h @@ -1,5 +1,3 @@ -// +build !purego - // Copyright 2020 ConsenSys Software Inc. // // Licensed under the Apache License, Version 2.0 (the "License"); @@ -17,23 +15,6 @@ #include "textflag.h" #include "funcdata.h" -// modulus q -DATA q<>+0(SB)/8, $0xd74916ea4570000d -DATA q<>+8(SB)/8, $0x3d369bd31147f73c -DATA q<>+16(SB)/8, $0xd7b5ce7ab839c225 -DATA q<>+24(SB)/8, $0x7e0e8850edbda407 -DATA q<>+32(SB)/8, $0xb8da9f5e83f57c49 -DATA q<>+40(SB)/8, $0x8152a6c0fadea490 -DATA q<>+48(SB)/8, $0x4e59769ad9bbda2f -DATA q<>+56(SB)/8, $0xa8fcd8c75d79d2c7 -DATA q<>+64(SB)/8, $0xfc1a174f01d72ab5 -DATA q<>+72(SB)/8, $0x0126633cc0f35f63 -GLOBL q<>(SB), (RODATA+NOPTR), $80 - -// qInv0 q'[0] -DATA qInv0<>(SB)/8, $0xb50f29ab0b03b13b -GLOBL qInv0<>(SB), (RODATA+NOPTR), $8 - #define REDUCE(ra0, ra1, ra2, ra3, ra4, ra5, ra6, ra7, ra8, ra9, rb0, rb1, rb2, rb3, rb4, rb5, rb6, rb7, rb8, rb9) \ MOVQ ra0, rb0; \ SUBQ q<>(SB), ra0; \ @@ -66,24 +47,384 @@ GLOBL qInv0<>(SB), (RODATA+NOPTR), $8 CMOVQCS rb8, ra8; \ CMOVQCS rb9, ra9; \ +TEXT ·reduce(SB), $56-8 + MOVQ res+0(FP), AX + MOVQ 0(AX), DX + MOVQ 8(AX), CX + MOVQ 16(AX), BX + MOVQ 24(AX), SI + MOVQ 32(AX), DI + MOVQ 40(AX), R8 + MOVQ 48(AX), R9 + MOVQ 56(AX), R10 + MOVQ 64(AX), R11 + MOVQ 72(AX), R12 + + // reduce element(DX,CX,BX,SI,DI,R8,R9,R10,R11,R12) using temp registers (R13,R14,R15,s0-8(SP),s1-16(SP),s2-24(SP),s3-32(SP),s4-40(SP),s5-48(SP),s6-56(SP)) + REDUCE(DX,CX,BX,SI,DI,R8,R9,R10,R11,R12,R13,R14,R15,s0-8(SP),s1-16(SP),s2-24(SP),s3-32(SP),s4-40(SP),s5-48(SP),s6-56(SP)) + + MOVQ DX, 0(AX) + MOVQ CX, 8(AX) + MOVQ BX, 16(AX) + MOVQ SI, 24(AX) + MOVQ DI, 32(AX) + MOVQ R8, 40(AX) + MOVQ R9, 48(AX) + MOVQ R10, 56(AX) + MOVQ R11, 64(AX) + MOVQ R12, 72(AX) + RET + +// MulBy3(x *Element) +TEXT ·MulBy3(SB), $56-8 + MOVQ x+0(FP), AX + MOVQ 0(AX), DX + MOVQ 8(AX), CX + MOVQ 16(AX), BX + MOVQ 24(AX), SI + MOVQ 32(AX), DI + MOVQ 40(AX), R8 + MOVQ 48(AX), R9 + MOVQ 
56(AX), R10 + MOVQ 64(AX), R11 + MOVQ 72(AX), R12 + ADDQ DX, DX + ADCQ CX, CX + ADCQ BX, BX + ADCQ SI, SI + ADCQ DI, DI + ADCQ R8, R8 + ADCQ R9, R9 + ADCQ R10, R10 + ADCQ R11, R11 + ADCQ R12, R12 + + // reduce element(DX,CX,BX,SI,DI,R8,R9,R10,R11,R12) using temp registers (R13,R14,R15,s0-8(SP),s1-16(SP),s2-24(SP),s3-32(SP),s4-40(SP),s5-48(SP),s6-56(SP)) + REDUCE(DX,CX,BX,SI,DI,R8,R9,R10,R11,R12,R13,R14,R15,s0-8(SP),s1-16(SP),s2-24(SP),s3-32(SP),s4-40(SP),s5-48(SP),s6-56(SP)) + + ADDQ 0(AX), DX + ADCQ 8(AX), CX + ADCQ 16(AX), BX + ADCQ 24(AX), SI + ADCQ 32(AX), DI + ADCQ 40(AX), R8 + ADCQ 48(AX), R9 + ADCQ 56(AX), R10 + ADCQ 64(AX), R11 + ADCQ 72(AX), R12 + + // reduce element(DX,CX,BX,SI,DI,R8,R9,R10,R11,R12) using temp registers (R13,R14,R15,s0-8(SP),s1-16(SP),s2-24(SP),s3-32(SP),s4-40(SP),s5-48(SP),s6-56(SP)) + REDUCE(DX,CX,BX,SI,DI,R8,R9,R10,R11,R12,R13,R14,R15,s0-8(SP),s1-16(SP),s2-24(SP),s3-32(SP),s4-40(SP),s5-48(SP),s6-56(SP)) + + MOVQ DX, 0(AX) + MOVQ CX, 8(AX) + MOVQ BX, 16(AX) + MOVQ SI, 24(AX) + MOVQ DI, 32(AX) + MOVQ R8, 40(AX) + MOVQ R9, 48(AX) + MOVQ R10, 56(AX) + MOVQ R11, 64(AX) + MOVQ R12, 72(AX) + RET + +// MulBy5(x *Element) +TEXT ·MulBy5(SB), $56-8 + MOVQ x+0(FP), AX + MOVQ 0(AX), DX + MOVQ 8(AX), CX + MOVQ 16(AX), BX + MOVQ 24(AX), SI + MOVQ 32(AX), DI + MOVQ 40(AX), R8 + MOVQ 48(AX), R9 + MOVQ 56(AX), R10 + MOVQ 64(AX), R11 + MOVQ 72(AX), R12 + ADDQ DX, DX + ADCQ CX, CX + ADCQ BX, BX + ADCQ SI, SI + ADCQ DI, DI + ADCQ R8, R8 + ADCQ R9, R9 + ADCQ R10, R10 + ADCQ R11, R11 + ADCQ R12, R12 + + // reduce element(DX,CX,BX,SI,DI,R8,R9,R10,R11,R12) using temp registers (R13,R14,R15,s0-8(SP),s1-16(SP),s2-24(SP),s3-32(SP),s4-40(SP),s5-48(SP),s6-56(SP)) + REDUCE(DX,CX,BX,SI,DI,R8,R9,R10,R11,R12,R13,R14,R15,s0-8(SP),s1-16(SP),s2-24(SP),s3-32(SP),s4-40(SP),s5-48(SP),s6-56(SP)) + + ADDQ DX, DX + ADCQ CX, CX + ADCQ BX, BX + ADCQ SI, SI + ADCQ DI, DI + ADCQ R8, R8 + ADCQ R9, R9 + ADCQ R10, R10 + ADCQ R11, R11 + ADCQ R12, R12 + + // reduce element(DX,CX,BX,SI,DI,R8,R9,R10,R11,R12) using temp registers (R13,R14,R15,s0-8(SP),s1-16(SP),s2-24(SP),s3-32(SP),s4-40(SP),s5-48(SP),s6-56(SP)) + REDUCE(DX,CX,BX,SI,DI,R8,R9,R10,R11,R12,R13,R14,R15,s0-8(SP),s1-16(SP),s2-24(SP),s3-32(SP),s4-40(SP),s5-48(SP),s6-56(SP)) + + ADDQ 0(AX), DX + ADCQ 8(AX), CX + ADCQ 16(AX), BX + ADCQ 24(AX), SI + ADCQ 32(AX), DI + ADCQ 40(AX), R8 + ADCQ 48(AX), R9 + ADCQ 56(AX), R10 + ADCQ 64(AX), R11 + ADCQ 72(AX), R12 + + // reduce element(DX,CX,BX,SI,DI,R8,R9,R10,R11,R12) using temp registers (R13,R14,R15,s0-8(SP),s1-16(SP),s2-24(SP),s3-32(SP),s4-40(SP),s5-48(SP),s6-56(SP)) + REDUCE(DX,CX,BX,SI,DI,R8,R9,R10,R11,R12,R13,R14,R15,s0-8(SP),s1-16(SP),s2-24(SP),s3-32(SP),s4-40(SP),s5-48(SP),s6-56(SP)) + + MOVQ DX, 0(AX) + MOVQ CX, 8(AX) + MOVQ BX, 16(AX) + MOVQ SI, 24(AX) + MOVQ DI, 32(AX) + MOVQ R8, 40(AX) + MOVQ R9, 48(AX) + MOVQ R10, 56(AX) + MOVQ R11, 64(AX) + MOVQ R12, 72(AX) + RET + +// MulBy13(x *Element) +TEXT ·MulBy13(SB), $136-8 + MOVQ x+0(FP), AX + MOVQ 0(AX), DX + MOVQ 8(AX), CX + MOVQ 16(AX), BX + MOVQ 24(AX), SI + MOVQ 32(AX), DI + MOVQ 40(AX), R8 + MOVQ 48(AX), R9 + MOVQ 56(AX), R10 + MOVQ 64(AX), R11 + MOVQ 72(AX), R12 + ADDQ DX, DX + ADCQ CX, CX + ADCQ BX, BX + ADCQ SI, SI + ADCQ DI, DI + ADCQ R8, R8 + ADCQ R9, R9 + ADCQ R10, R10 + ADCQ R11, R11 + ADCQ R12, R12 + + // reduce element(DX,CX,BX,SI,DI,R8,R9,R10,R11,R12) using temp registers (R13,R14,R15,s0-8(SP),s1-16(SP),s2-24(SP),s3-32(SP),s4-40(SP),s5-48(SP),s6-56(SP)) + 
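+	// note (added for clarity): 13*x is assembled as 8*x + 4*x + x, i.e. three doublings with a saved
+	// 4*x copy; every doubling/addition below is followed by a conditional-subtract REDUCE to stay below q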
REDUCE(DX,CX,BX,SI,DI,R8,R9,R10,R11,R12,R13,R14,R15,s0-8(SP),s1-16(SP),s2-24(SP),s3-32(SP),s4-40(SP),s5-48(SP),s6-56(SP)) + + ADDQ DX, DX + ADCQ CX, CX + ADCQ BX, BX + ADCQ SI, SI + ADCQ DI, DI + ADCQ R8, R8 + ADCQ R9, R9 + ADCQ R10, R10 + ADCQ R11, R11 + ADCQ R12, R12 + + // reduce element(DX,CX,BX,SI,DI,R8,R9,R10,R11,R12) using temp registers (s7-64(SP),s8-72(SP),s9-80(SP),s10-88(SP),s11-96(SP),s12-104(SP),s13-112(SP),s14-120(SP),s15-128(SP),s16-136(SP)) + REDUCE(DX,CX,BX,SI,DI,R8,R9,R10,R11,R12,s7-64(SP),s8-72(SP),s9-80(SP),s10-88(SP),s11-96(SP),s12-104(SP),s13-112(SP),s14-120(SP),s15-128(SP),s16-136(SP)) + + MOVQ DX, s7-64(SP) + MOVQ CX, s8-72(SP) + MOVQ BX, s9-80(SP) + MOVQ SI, s10-88(SP) + MOVQ DI, s11-96(SP) + MOVQ R8, s12-104(SP) + MOVQ R9, s13-112(SP) + MOVQ R10, s14-120(SP) + MOVQ R11, s15-128(SP) + MOVQ R12, s16-136(SP) + ADDQ DX, DX + ADCQ CX, CX + ADCQ BX, BX + ADCQ SI, SI + ADCQ DI, DI + ADCQ R8, R8 + ADCQ R9, R9 + ADCQ R10, R10 + ADCQ R11, R11 + ADCQ R12, R12 + + // reduce element(DX,CX,BX,SI,DI,R8,R9,R10,R11,R12) using temp registers (R13,R14,R15,s0-8(SP),s1-16(SP),s2-24(SP),s3-32(SP),s4-40(SP),s5-48(SP),s6-56(SP)) + REDUCE(DX,CX,BX,SI,DI,R8,R9,R10,R11,R12,R13,R14,R15,s0-8(SP),s1-16(SP),s2-24(SP),s3-32(SP),s4-40(SP),s5-48(SP),s6-56(SP)) + + ADDQ s7-64(SP), DX + ADCQ s8-72(SP), CX + ADCQ s9-80(SP), BX + ADCQ s10-88(SP), SI + ADCQ s11-96(SP), DI + ADCQ s12-104(SP), R8 + ADCQ s13-112(SP), R9 + ADCQ s14-120(SP), R10 + ADCQ s15-128(SP), R11 + ADCQ s16-136(SP), R12 + + // reduce element(DX,CX,BX,SI,DI,R8,R9,R10,R11,R12) using temp registers (R13,R14,R15,s0-8(SP),s1-16(SP),s2-24(SP),s3-32(SP),s4-40(SP),s5-48(SP),s6-56(SP)) + REDUCE(DX,CX,BX,SI,DI,R8,R9,R10,R11,R12,R13,R14,R15,s0-8(SP),s1-16(SP),s2-24(SP),s3-32(SP),s4-40(SP),s5-48(SP),s6-56(SP)) + + ADDQ 0(AX), DX + ADCQ 8(AX), CX + ADCQ 16(AX), BX + ADCQ 24(AX), SI + ADCQ 32(AX), DI + ADCQ 40(AX), R8 + ADCQ 48(AX), R9 + ADCQ 56(AX), R10 + ADCQ 64(AX), R11 + ADCQ 72(AX), R12 + + // reduce element(DX,CX,BX,SI,DI,R8,R9,R10,R11,R12) using temp registers (R13,R14,R15,s0-8(SP),s1-16(SP),s2-24(SP),s3-32(SP),s4-40(SP),s5-48(SP),s6-56(SP)) + REDUCE(DX,CX,BX,SI,DI,R8,R9,R10,R11,R12,R13,R14,R15,s0-8(SP),s1-16(SP),s2-24(SP),s3-32(SP),s4-40(SP),s5-48(SP),s6-56(SP)) + + MOVQ DX, 0(AX) + MOVQ CX, 8(AX) + MOVQ BX, 16(AX) + MOVQ SI, 24(AX) + MOVQ DI, 32(AX) + MOVQ R8, 40(AX) + MOVQ R9, 48(AX) + MOVQ R10, 56(AX) + MOVQ R11, 64(AX) + MOVQ R12, 72(AX) + RET + +// Butterfly(a, b *Element) sets a = a + b; b = a - b +TEXT ·Butterfly(SB), $56-16 + MOVQ b+8(FP), AX + MOVQ 0(AX), DX + MOVQ 8(AX), CX + MOVQ 16(AX), BX + MOVQ 24(AX), SI + MOVQ 32(AX), DI + MOVQ 40(AX), R8 + MOVQ 48(AX), R9 + MOVQ 56(AX), R10 + MOVQ 64(AX), R11 + MOVQ 72(AX), R12 + MOVQ a+0(FP), AX + ADDQ 0(AX), DX + ADCQ 8(AX), CX + ADCQ 16(AX), BX + ADCQ 24(AX), SI + ADCQ 32(AX), DI + ADCQ 40(AX), R8 + ADCQ 48(AX), R9 + ADCQ 56(AX), R10 + ADCQ 64(AX), R11 + ADCQ 72(AX), R12 + MOVQ DX, R13 + MOVQ CX, R14 + MOVQ BX, R15 + MOVQ SI, s0-8(SP) + MOVQ DI, s1-16(SP) + MOVQ R8, s2-24(SP) + MOVQ R9, s3-32(SP) + MOVQ R10, s4-40(SP) + MOVQ R11, s5-48(SP) + MOVQ R12, s6-56(SP) + MOVQ 0(AX), DX + MOVQ 8(AX), CX + MOVQ 16(AX), BX + MOVQ 24(AX), SI + MOVQ 32(AX), DI + MOVQ 40(AX), R8 + MOVQ 48(AX), R9 + MOVQ 56(AX), R10 + MOVQ 64(AX), R11 + MOVQ 72(AX), R12 + MOVQ b+8(FP), AX + SUBQ 0(AX), DX + SBBQ 8(AX), CX + SBBQ 16(AX), BX + SBBQ 24(AX), SI + SBBQ 32(AX), DI + SBBQ 40(AX), R8 + SBBQ 48(AX), R9 + SBBQ 56(AX), R10 + SBBQ 64(AX), R11 + SBBQ 72(AX), R12 + JCC noReduce_1 + MOVQ q0, AX + ADDQ AX, DX + MOVQ 
q1, AX + ADCQ AX, CX + MOVQ q2, AX + ADCQ AX, BX + MOVQ q3, AX + ADCQ AX, SI + MOVQ q4, AX + ADCQ AX, DI + MOVQ q5, AX + ADCQ AX, R8 + MOVQ q6, AX + ADCQ AX, R9 + MOVQ q7, AX + ADCQ AX, R10 + MOVQ q8, AX + ADCQ AX, R11 + MOVQ q9, AX + ADCQ AX, R12 + +noReduce_1: + MOVQ b+8(FP), AX + MOVQ DX, 0(AX) + MOVQ CX, 8(AX) + MOVQ BX, 16(AX) + MOVQ SI, 24(AX) + MOVQ DI, 32(AX) + MOVQ R8, 40(AX) + MOVQ R9, 48(AX) + MOVQ R10, 56(AX) + MOVQ R11, 64(AX) + MOVQ R12, 72(AX) + MOVQ R13, DX + MOVQ R14, CX + MOVQ R15, BX + MOVQ s0-8(SP), SI + MOVQ s1-16(SP), DI + MOVQ s2-24(SP), R8 + MOVQ s3-32(SP), R9 + MOVQ s4-40(SP), R10 + MOVQ s5-48(SP), R11 + MOVQ s6-56(SP), R12 + + // reduce element(DX,CX,BX,SI,DI,R8,R9,R10,R11,R12) using temp registers (R13,R14,R15,s0-8(SP),s1-16(SP),s2-24(SP),s3-32(SP),s4-40(SP),s5-48(SP),s6-56(SP)) + REDUCE(DX,CX,BX,SI,DI,R8,R9,R10,R11,R12,R13,R14,R15,s0-8(SP),s1-16(SP),s2-24(SP),s3-32(SP),s4-40(SP),s5-48(SP),s6-56(SP)) + + MOVQ a+0(FP), AX + MOVQ DX, 0(AX) + MOVQ CX, 8(AX) + MOVQ BX, 16(AX) + MOVQ SI, 24(AX) + MOVQ DI, 32(AX) + MOVQ R8, 40(AX) + MOVQ R9, 48(AX) + MOVQ R10, 56(AX) + MOVQ R11, 64(AX) + MOVQ R12, 72(AX) + RET + // mul(res, x, y *Element) TEXT ·mul(SB), $64-24 - // the algorithm is described in the Element.Mul declaration (.go) - // however, to benefit from the ADCX and ADOX carry chains - // we split the inner loops in 2: - // for i=0 to N-1 - // for j=0 to N-1 - // (A,t[j]) := t[j] + x[j]*y[i] + A - // m := t[0]*q'[0] mod W - // C,_ := t[0] + m*q[0] - // for j=1 to N-1 - // (C,t[j-1]) := t[j] + m*q[j] + C - // t[N-1] = C + A + // Algorithm 2 of "Faster Montgomery Multiplication and Multi-Scalar-Multiplication for SNARKS" + // by Y. El Housni and G. Botrel https://doi.org/10.46586/tches.v2023.i3.504-521 NO_LOCAL_POINTERS CMPB ·supportAdx(SB), $1 - JNE noAdx_1 + JNE noAdx_2 MOVQ x+8(FP), R12 MOVQ y+16(FP), R13 @@ -1323,7 +1664,7 @@ TEXT ·mul(SB), $64-24 MOVQ R11, 72(AX) RET -noAdx_1: +noAdx_2: MOVQ res+0(FP), AX MOVQ AX, (SP) MOVQ x+8(FP), AX @@ -1336,8 +1677,8 @@ TEXT ·mul(SB), $64-24 TEXT ·fromMont(SB), $64-8 NO_LOCAL_POINTERS - // the algorithm is described here - // https://hackmd.io/@gnark/modular_multiplication + // Algorithm 2 of "Faster Montgomery Multiplication and Multi-Scalar-Multiplication for SNARKS" + // by Y. El Housni and G. Botrel https://doi.org/10.46586/tches.v2023.i3.504-521 // when y = 1 we have: // for i=0 to N-1 // t[i] = x[i] @@ -1348,7 +1689,7 @@ TEXT ·fromMont(SB), $64-8 // (C,t[j-1]) := t[j] + m*q[j] + C // t[N-1] = C CMPB ·supportAdx(SB), $1 - JNE noAdx_2 + JNE noAdx_3 MOVQ res+0(FP), DX MOVQ 0(DX), R14 MOVQ 8(DX), R15 @@ -1967,7 +2308,7 @@ TEXT ·fromMont(SB), $64-8 MOVQ R11, 72(AX) RET -noAdx_2: +noAdx_3: MOVQ res+0(FP), AX MOVQ AX, (SP) CALL ·_fromMontGeneric(SB) diff --git a/ecc/bw6-761/fp/element_mul_amd64.s b/field/asm/element_12w_amd64.h similarity index 81% rename from ecc/bw6-761/fp/element_mul_amd64.s rename to field/asm/element_12w_amd64.h index fd48d8606..b331dd8c6 100644 --- a/ecc/bw6-761/fp/element_mul_amd64.s +++ b/field/asm/element_12w_amd64.h @@ -1,5 +1,3 @@ -// +build !purego - // Copyright 2020 ConsenSys Software Inc. 
// // Licensed under the Apache License, Version 2.0 (the "License"); @@ -17,25 +15,6 @@ #include "textflag.h" #include "funcdata.h" -// modulus q -DATA q<>+0(SB)/8, $0xf49d00000000008b -DATA q<>+8(SB)/8, $0xe6913e6870000082 -DATA q<>+16(SB)/8, $0x160cf8aeeaf0a437 -DATA q<>+24(SB)/8, $0x98a116c25667a8f8 -DATA q<>+32(SB)/8, $0x71dcd3dc73ebff2e -DATA q<>+40(SB)/8, $0x8689c8ed12f9fd90 -DATA q<>+48(SB)/8, $0x03cebaff25b42304 -DATA q<>+56(SB)/8, $0x707ba638e584e919 -DATA q<>+64(SB)/8, $0x528275ef8087be41 -DATA q<>+72(SB)/8, $0xb926186a81d14688 -DATA q<>+80(SB)/8, $0xd187c94004faff3e -DATA q<>+88(SB)/8, $0x0122e824fb83ce0a -GLOBL q<>(SB), (RODATA+NOPTR), $96 - -// qInv0 q'[0] -DATA qInv0<>(SB)/8, $0x0a5593568fa798dd -GLOBL qInv0<>(SB), (RODATA+NOPTR), $8 - #define REDUCE(ra0, ra1, ra2, ra3, ra4, ra5, ra6, ra7, ra8, ra9, ra10, ra11, rb0, rb1, rb2, rb3, rb4, rb5, rb6, rb7, rb8, rb9, rb10, rb11) \ MOVQ ra0, rb0; \ SUBQ q<>(SB), ra0; \ @@ -74,24 +53,442 @@ GLOBL qInv0<>(SB), (RODATA+NOPTR), $8 CMOVQCS rb10, ra10; \ CMOVQCS rb11, ra11; \ +TEXT ·reduce(SB), $88-8 + MOVQ res+0(FP), AX + MOVQ 0(AX), DX + MOVQ 8(AX), CX + MOVQ 16(AX), BX + MOVQ 24(AX), SI + MOVQ 32(AX), DI + MOVQ 40(AX), R8 + MOVQ 48(AX), R9 + MOVQ 56(AX), R10 + MOVQ 64(AX), R11 + MOVQ 72(AX), R12 + MOVQ 80(AX), R13 + MOVQ 88(AX), R14 + + // reduce element(DX,CX,BX,SI,DI,R8,R9,R10,R11,R12,R13,R14) using temp registers (R15,s0-8(SP),s1-16(SP),s2-24(SP),s3-32(SP),s4-40(SP),s5-48(SP),s6-56(SP),s7-64(SP),s8-72(SP),s9-80(SP),s10-88(SP)) + REDUCE(DX,CX,BX,SI,DI,R8,R9,R10,R11,R12,R13,R14,R15,s0-8(SP),s1-16(SP),s2-24(SP),s3-32(SP),s4-40(SP),s5-48(SP),s6-56(SP),s7-64(SP),s8-72(SP),s9-80(SP),s10-88(SP)) + + MOVQ DX, 0(AX) + MOVQ CX, 8(AX) + MOVQ BX, 16(AX) + MOVQ SI, 24(AX) + MOVQ DI, 32(AX) + MOVQ R8, 40(AX) + MOVQ R9, 48(AX) + MOVQ R10, 56(AX) + MOVQ R11, 64(AX) + MOVQ R12, 72(AX) + MOVQ R13, 80(AX) + MOVQ R14, 88(AX) + RET + +// MulBy3(x *Element) +TEXT ·MulBy3(SB), $88-8 + MOVQ x+0(FP), AX + MOVQ 0(AX), DX + MOVQ 8(AX), CX + MOVQ 16(AX), BX + MOVQ 24(AX), SI + MOVQ 32(AX), DI + MOVQ 40(AX), R8 + MOVQ 48(AX), R9 + MOVQ 56(AX), R10 + MOVQ 64(AX), R11 + MOVQ 72(AX), R12 + MOVQ 80(AX), R13 + MOVQ 88(AX), R14 + ADDQ DX, DX + ADCQ CX, CX + ADCQ BX, BX + ADCQ SI, SI + ADCQ DI, DI + ADCQ R8, R8 + ADCQ R9, R9 + ADCQ R10, R10 + ADCQ R11, R11 + ADCQ R12, R12 + ADCQ R13, R13 + ADCQ R14, R14 + + // reduce element(DX,CX,BX,SI,DI,R8,R9,R10,R11,R12,R13,R14) using temp registers (R15,s0-8(SP),s1-16(SP),s2-24(SP),s3-32(SP),s4-40(SP),s5-48(SP),s6-56(SP),s7-64(SP),s8-72(SP),s9-80(SP),s10-88(SP)) + REDUCE(DX,CX,BX,SI,DI,R8,R9,R10,R11,R12,R13,R14,R15,s0-8(SP),s1-16(SP),s2-24(SP),s3-32(SP),s4-40(SP),s5-48(SP),s6-56(SP),s7-64(SP),s8-72(SP),s9-80(SP),s10-88(SP)) + + ADDQ 0(AX), DX + ADCQ 8(AX), CX + ADCQ 16(AX), BX + ADCQ 24(AX), SI + ADCQ 32(AX), DI + ADCQ 40(AX), R8 + ADCQ 48(AX), R9 + ADCQ 56(AX), R10 + ADCQ 64(AX), R11 + ADCQ 72(AX), R12 + ADCQ 80(AX), R13 + ADCQ 88(AX), R14 + + // reduce element(DX,CX,BX,SI,DI,R8,R9,R10,R11,R12,R13,R14) using temp registers (R15,s0-8(SP),s1-16(SP),s2-24(SP),s3-32(SP),s4-40(SP),s5-48(SP),s6-56(SP),s7-64(SP),s8-72(SP),s9-80(SP),s10-88(SP)) + REDUCE(DX,CX,BX,SI,DI,R8,R9,R10,R11,R12,R13,R14,R15,s0-8(SP),s1-16(SP),s2-24(SP),s3-32(SP),s4-40(SP),s5-48(SP),s6-56(SP),s7-64(SP),s8-72(SP),s9-80(SP),s10-88(SP)) + + MOVQ DX, 0(AX) + MOVQ CX, 8(AX) + MOVQ BX, 16(AX) + MOVQ SI, 24(AX) + MOVQ DI, 32(AX) + MOVQ R8, 40(AX) + MOVQ R9, 48(AX) + MOVQ R10, 56(AX) + MOVQ R11, 64(AX) + MOVQ R12, 72(AX) + MOVQ R13, 80(AX) + MOVQ R14, 88(AX) + RET + +// 
MulBy5(x *Element) +TEXT ·MulBy5(SB), $88-8 + MOVQ x+0(FP), AX + MOVQ 0(AX), DX + MOVQ 8(AX), CX + MOVQ 16(AX), BX + MOVQ 24(AX), SI + MOVQ 32(AX), DI + MOVQ 40(AX), R8 + MOVQ 48(AX), R9 + MOVQ 56(AX), R10 + MOVQ 64(AX), R11 + MOVQ 72(AX), R12 + MOVQ 80(AX), R13 + MOVQ 88(AX), R14 + ADDQ DX, DX + ADCQ CX, CX + ADCQ BX, BX + ADCQ SI, SI + ADCQ DI, DI + ADCQ R8, R8 + ADCQ R9, R9 + ADCQ R10, R10 + ADCQ R11, R11 + ADCQ R12, R12 + ADCQ R13, R13 + ADCQ R14, R14 + + // reduce element(DX,CX,BX,SI,DI,R8,R9,R10,R11,R12,R13,R14) using temp registers (R15,s0-8(SP),s1-16(SP),s2-24(SP),s3-32(SP),s4-40(SP),s5-48(SP),s6-56(SP),s7-64(SP),s8-72(SP),s9-80(SP),s10-88(SP)) + REDUCE(DX,CX,BX,SI,DI,R8,R9,R10,R11,R12,R13,R14,R15,s0-8(SP),s1-16(SP),s2-24(SP),s3-32(SP),s4-40(SP),s5-48(SP),s6-56(SP),s7-64(SP),s8-72(SP),s9-80(SP),s10-88(SP)) + + ADDQ DX, DX + ADCQ CX, CX + ADCQ BX, BX + ADCQ SI, SI + ADCQ DI, DI + ADCQ R8, R8 + ADCQ R9, R9 + ADCQ R10, R10 + ADCQ R11, R11 + ADCQ R12, R12 + ADCQ R13, R13 + ADCQ R14, R14 + + // reduce element(DX,CX,BX,SI,DI,R8,R9,R10,R11,R12,R13,R14) using temp registers (R15,s0-8(SP),s1-16(SP),s2-24(SP),s3-32(SP),s4-40(SP),s5-48(SP),s6-56(SP),s7-64(SP),s8-72(SP),s9-80(SP),s10-88(SP)) + REDUCE(DX,CX,BX,SI,DI,R8,R9,R10,R11,R12,R13,R14,R15,s0-8(SP),s1-16(SP),s2-24(SP),s3-32(SP),s4-40(SP),s5-48(SP),s6-56(SP),s7-64(SP),s8-72(SP),s9-80(SP),s10-88(SP)) + + ADDQ 0(AX), DX + ADCQ 8(AX), CX + ADCQ 16(AX), BX + ADCQ 24(AX), SI + ADCQ 32(AX), DI + ADCQ 40(AX), R8 + ADCQ 48(AX), R9 + ADCQ 56(AX), R10 + ADCQ 64(AX), R11 + ADCQ 72(AX), R12 + ADCQ 80(AX), R13 + ADCQ 88(AX), R14 + + // reduce element(DX,CX,BX,SI,DI,R8,R9,R10,R11,R12,R13,R14) using temp registers (R15,s0-8(SP),s1-16(SP),s2-24(SP),s3-32(SP),s4-40(SP),s5-48(SP),s6-56(SP),s7-64(SP),s8-72(SP),s9-80(SP),s10-88(SP)) + REDUCE(DX,CX,BX,SI,DI,R8,R9,R10,R11,R12,R13,R14,R15,s0-8(SP),s1-16(SP),s2-24(SP),s3-32(SP),s4-40(SP),s5-48(SP),s6-56(SP),s7-64(SP),s8-72(SP),s9-80(SP),s10-88(SP)) + + MOVQ DX, 0(AX) + MOVQ CX, 8(AX) + MOVQ BX, 16(AX) + MOVQ SI, 24(AX) + MOVQ DI, 32(AX) + MOVQ R8, 40(AX) + MOVQ R9, 48(AX) + MOVQ R10, 56(AX) + MOVQ R11, 64(AX) + MOVQ R12, 72(AX) + MOVQ R13, 80(AX) + MOVQ R14, 88(AX) + RET + +// MulBy13(x *Element) +TEXT ·MulBy13(SB), $184-8 + MOVQ x+0(FP), AX + MOVQ 0(AX), DX + MOVQ 8(AX), CX + MOVQ 16(AX), BX + MOVQ 24(AX), SI + MOVQ 32(AX), DI + MOVQ 40(AX), R8 + MOVQ 48(AX), R9 + MOVQ 56(AX), R10 + MOVQ 64(AX), R11 + MOVQ 72(AX), R12 + MOVQ 80(AX), R13 + MOVQ 88(AX), R14 + ADDQ DX, DX + ADCQ CX, CX + ADCQ BX, BX + ADCQ SI, SI + ADCQ DI, DI + ADCQ R8, R8 + ADCQ R9, R9 + ADCQ R10, R10 + ADCQ R11, R11 + ADCQ R12, R12 + ADCQ R13, R13 + ADCQ R14, R14 + + // reduce element(DX,CX,BX,SI,DI,R8,R9,R10,R11,R12,R13,R14) using temp registers (R15,s0-8(SP),s1-16(SP),s2-24(SP),s3-32(SP),s4-40(SP),s5-48(SP),s6-56(SP),s7-64(SP),s8-72(SP),s9-80(SP),s10-88(SP)) + REDUCE(DX,CX,BX,SI,DI,R8,R9,R10,R11,R12,R13,R14,R15,s0-8(SP),s1-16(SP),s2-24(SP),s3-32(SP),s4-40(SP),s5-48(SP),s6-56(SP),s7-64(SP),s8-72(SP),s9-80(SP),s10-88(SP)) + + ADDQ DX, DX + ADCQ CX, CX + ADCQ BX, BX + ADCQ SI, SI + ADCQ DI, DI + ADCQ R8, R8 + ADCQ R9, R9 + ADCQ R10, R10 + ADCQ R11, R11 + ADCQ R12, R12 + ADCQ R13, R13 + ADCQ R14, R14 + + // reduce element(DX,CX,BX,SI,DI,R8,R9,R10,R11,R12,R13,R14) using temp registers (s11-96(SP),s12-104(SP),s13-112(SP),s14-120(SP),s15-128(SP),s16-136(SP),s17-144(SP),s18-152(SP),s19-160(SP),s20-168(SP),s21-176(SP),s22-184(SP)) + 
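+	// note (added for clarity): the 4*x intermediate is copied out to s11-96(SP)..s22-184(SP) right after
+	// this reduction and added back once 8*x is ready, giving 12*x; the final ADDQ of x yields 13*x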
REDUCE(DX,CX,BX,SI,DI,R8,R9,R10,R11,R12,R13,R14,s11-96(SP),s12-104(SP),s13-112(SP),s14-120(SP),s15-128(SP),s16-136(SP),s17-144(SP),s18-152(SP),s19-160(SP),s20-168(SP),s21-176(SP),s22-184(SP)) + + MOVQ DX, s11-96(SP) + MOVQ CX, s12-104(SP) + MOVQ BX, s13-112(SP) + MOVQ SI, s14-120(SP) + MOVQ DI, s15-128(SP) + MOVQ R8, s16-136(SP) + MOVQ R9, s17-144(SP) + MOVQ R10, s18-152(SP) + MOVQ R11, s19-160(SP) + MOVQ R12, s20-168(SP) + MOVQ R13, s21-176(SP) + MOVQ R14, s22-184(SP) + ADDQ DX, DX + ADCQ CX, CX + ADCQ BX, BX + ADCQ SI, SI + ADCQ DI, DI + ADCQ R8, R8 + ADCQ R9, R9 + ADCQ R10, R10 + ADCQ R11, R11 + ADCQ R12, R12 + ADCQ R13, R13 + ADCQ R14, R14 + + // reduce element(DX,CX,BX,SI,DI,R8,R9,R10,R11,R12,R13,R14) using temp registers (R15,s0-8(SP),s1-16(SP),s2-24(SP),s3-32(SP),s4-40(SP),s5-48(SP),s6-56(SP),s7-64(SP),s8-72(SP),s9-80(SP),s10-88(SP)) + REDUCE(DX,CX,BX,SI,DI,R8,R9,R10,R11,R12,R13,R14,R15,s0-8(SP),s1-16(SP),s2-24(SP),s3-32(SP),s4-40(SP),s5-48(SP),s6-56(SP),s7-64(SP),s8-72(SP),s9-80(SP),s10-88(SP)) + + ADDQ s11-96(SP), DX + ADCQ s12-104(SP), CX + ADCQ s13-112(SP), BX + ADCQ s14-120(SP), SI + ADCQ s15-128(SP), DI + ADCQ s16-136(SP), R8 + ADCQ s17-144(SP), R9 + ADCQ s18-152(SP), R10 + ADCQ s19-160(SP), R11 + ADCQ s20-168(SP), R12 + ADCQ s21-176(SP), R13 + ADCQ s22-184(SP), R14 + + // reduce element(DX,CX,BX,SI,DI,R8,R9,R10,R11,R12,R13,R14) using temp registers (R15,s0-8(SP),s1-16(SP),s2-24(SP),s3-32(SP),s4-40(SP),s5-48(SP),s6-56(SP),s7-64(SP),s8-72(SP),s9-80(SP),s10-88(SP)) + REDUCE(DX,CX,BX,SI,DI,R8,R9,R10,R11,R12,R13,R14,R15,s0-8(SP),s1-16(SP),s2-24(SP),s3-32(SP),s4-40(SP),s5-48(SP),s6-56(SP),s7-64(SP),s8-72(SP),s9-80(SP),s10-88(SP)) + + ADDQ 0(AX), DX + ADCQ 8(AX), CX + ADCQ 16(AX), BX + ADCQ 24(AX), SI + ADCQ 32(AX), DI + ADCQ 40(AX), R8 + ADCQ 48(AX), R9 + ADCQ 56(AX), R10 + ADCQ 64(AX), R11 + ADCQ 72(AX), R12 + ADCQ 80(AX), R13 + ADCQ 88(AX), R14 + + // reduce element(DX,CX,BX,SI,DI,R8,R9,R10,R11,R12,R13,R14) using temp registers (R15,s0-8(SP),s1-16(SP),s2-24(SP),s3-32(SP),s4-40(SP),s5-48(SP),s6-56(SP),s7-64(SP),s8-72(SP),s9-80(SP),s10-88(SP)) + REDUCE(DX,CX,BX,SI,DI,R8,R9,R10,R11,R12,R13,R14,R15,s0-8(SP),s1-16(SP),s2-24(SP),s3-32(SP),s4-40(SP),s5-48(SP),s6-56(SP),s7-64(SP),s8-72(SP),s9-80(SP),s10-88(SP)) + + MOVQ DX, 0(AX) + MOVQ CX, 8(AX) + MOVQ BX, 16(AX) + MOVQ SI, 24(AX) + MOVQ DI, 32(AX) + MOVQ R8, 40(AX) + MOVQ R9, 48(AX) + MOVQ R10, 56(AX) + MOVQ R11, 64(AX) + MOVQ R12, 72(AX) + MOVQ R13, 80(AX) + MOVQ R14, 88(AX) + RET + +// Butterfly(a, b *Element) sets a = a + b; b = a - b +TEXT ·Butterfly(SB), $88-16 + MOVQ b+8(FP), AX + MOVQ 0(AX), DX + MOVQ 8(AX), CX + MOVQ 16(AX), BX + MOVQ 24(AX), SI + MOVQ 32(AX), DI + MOVQ 40(AX), R8 + MOVQ 48(AX), R9 + MOVQ 56(AX), R10 + MOVQ 64(AX), R11 + MOVQ 72(AX), R12 + MOVQ 80(AX), R13 + MOVQ 88(AX), R14 + MOVQ a+0(FP), AX + ADDQ 0(AX), DX + ADCQ 8(AX), CX + ADCQ 16(AX), BX + ADCQ 24(AX), SI + ADCQ 32(AX), DI + ADCQ 40(AX), R8 + ADCQ 48(AX), R9 + ADCQ 56(AX), R10 + ADCQ 64(AX), R11 + ADCQ 72(AX), R12 + ADCQ 80(AX), R13 + ADCQ 88(AX), R14 + MOVQ DX, R15 + MOVQ CX, s0-8(SP) + MOVQ BX, s1-16(SP) + MOVQ SI, s2-24(SP) + MOVQ DI, s3-32(SP) + MOVQ R8, s4-40(SP) + MOVQ R9, s5-48(SP) + MOVQ R10, s6-56(SP) + MOVQ R11, s7-64(SP) + MOVQ R12, s8-72(SP) + MOVQ R13, s9-80(SP) + MOVQ R14, s10-88(SP) + MOVQ 0(AX), DX + MOVQ 8(AX), CX + MOVQ 16(AX), BX + MOVQ 24(AX), SI + MOVQ 32(AX), DI + MOVQ 40(AX), R8 + MOVQ 48(AX), R9 + MOVQ 56(AX), R10 + MOVQ 64(AX), R11 + MOVQ 72(AX), R12 + MOVQ 80(AX), R13 + MOVQ 88(AX), R14 + MOVQ b+8(FP), AX + SUBQ 0(AX), DX + SBBQ 
8(AX), CX + SBBQ 16(AX), BX + SBBQ 24(AX), SI + SBBQ 32(AX), DI + SBBQ 40(AX), R8 + SBBQ 48(AX), R9 + SBBQ 56(AX), R10 + SBBQ 64(AX), R11 + SBBQ 72(AX), R12 + SBBQ 80(AX), R13 + SBBQ 88(AX), R14 + JCC noReduce_1 + MOVQ q0, AX + ADDQ AX, DX + MOVQ q1, AX + ADCQ AX, CX + MOVQ q2, AX + ADCQ AX, BX + MOVQ q3, AX + ADCQ AX, SI + MOVQ q4, AX + ADCQ AX, DI + MOVQ q5, AX + ADCQ AX, R8 + MOVQ q6, AX + ADCQ AX, R9 + MOVQ q7, AX + ADCQ AX, R10 + MOVQ q8, AX + ADCQ AX, R11 + MOVQ q9, AX + ADCQ AX, R12 + MOVQ q10, AX + ADCQ AX, R13 + MOVQ q11, AX + ADCQ AX, R14 + +noReduce_1: + MOVQ b+8(FP), AX + MOVQ DX, 0(AX) + MOVQ CX, 8(AX) + MOVQ BX, 16(AX) + MOVQ SI, 24(AX) + MOVQ DI, 32(AX) + MOVQ R8, 40(AX) + MOVQ R9, 48(AX) + MOVQ R10, 56(AX) + MOVQ R11, 64(AX) + MOVQ R12, 72(AX) + MOVQ R13, 80(AX) + MOVQ R14, 88(AX) + MOVQ R15, DX + MOVQ s0-8(SP), CX + MOVQ s1-16(SP), BX + MOVQ s2-24(SP), SI + MOVQ s3-32(SP), DI + MOVQ s4-40(SP), R8 + MOVQ s5-48(SP), R9 + MOVQ s6-56(SP), R10 + MOVQ s7-64(SP), R11 + MOVQ s8-72(SP), R12 + MOVQ s9-80(SP), R13 + MOVQ s10-88(SP), R14 + + // reduce element(DX,CX,BX,SI,DI,R8,R9,R10,R11,R12,R13,R14) using temp registers (R15,s0-8(SP),s1-16(SP),s2-24(SP),s3-32(SP),s4-40(SP),s5-48(SP),s6-56(SP),s7-64(SP),s8-72(SP),s9-80(SP),s10-88(SP)) + REDUCE(DX,CX,BX,SI,DI,R8,R9,R10,R11,R12,R13,R14,R15,s0-8(SP),s1-16(SP),s2-24(SP),s3-32(SP),s4-40(SP),s5-48(SP),s6-56(SP),s7-64(SP),s8-72(SP),s9-80(SP),s10-88(SP)) + + MOVQ a+0(FP), AX + MOVQ DX, 0(AX) + MOVQ CX, 8(AX) + MOVQ BX, 16(AX) + MOVQ SI, 24(AX) + MOVQ DI, 32(AX) + MOVQ R8, 40(AX) + MOVQ R9, 48(AX) + MOVQ R10, 56(AX) + MOVQ R11, 64(AX) + MOVQ R12, 72(AX) + MOVQ R13, 80(AX) + MOVQ R14, 88(AX) + RET + // mul(res, x, y *Element) TEXT ·mul(SB), $96-24 - // the algorithm is described in the Element.Mul declaration (.go) - // however, to benefit from the ADCX and ADOX carry chains - // we split the inner loops in 2: - // for i=0 to N-1 - // for j=0 to N-1 - // (A,t[j]) := t[j] + x[j]*y[i] + A - // m := t[0]*q'[0] mod W - // C,_ := t[0] + m*q[0] - // for j=1 to N-1 - // (C,t[j-1]) := t[j] + m*q[j] + C - // t[N-1] = C + A + // Algorithm 2 of "Faster Montgomery Multiplication and Multi-Scalar-Multiplication for SNARKS" + // by Y. El Housni and G. Botrel https://doi.org/10.46586/tches.v2023.i3.504-521 NO_LOCAL_POINTERS CMPB ·supportAdx(SB), $1 - JNE noAdx_1 + JNE noAdx_2 MOVQ x+8(FP), AX // x[0] -> s0-8(SP) @@ -1865,7 +2262,7 @@ TEXT ·mul(SB), $96-24 MOVQ R13, 88(AX) RET -noAdx_1: +noAdx_2: MOVQ res+0(FP), AX MOVQ AX, (SP) MOVQ x+8(FP), AX @@ -1878,8 +2275,8 @@ TEXT ·mul(SB), $96-24 TEXT ·fromMont(SB), $96-8 NO_LOCAL_POINTERS - // the algorithm is described here - // https://hackmd.io/@gnark/modular_multiplication + // Algorithm 2 of "Faster Montgomery Multiplication and Multi-Scalar-Multiplication for SNARKS" + // by Y. El Housni and G. Botrel https://doi.org/10.46586/tches.v2023.i3.504-521 // when y = 1 we have: // for i=0 to N-1 // t[i] = x[i] @@ -1890,7 +2287,7 @@ TEXT ·fromMont(SB), $96-8 // (C,t[j-1]) := t[j] + m*q[j] + C // t[N-1] = C CMPB ·supportAdx(SB), $1 - JNE noAdx_2 + JNE noAdx_3 MOVQ res+0(FP), DX MOVQ 0(DX), R14 MOVQ 8(DX), R15 @@ -2751,7 +3148,7 @@ TEXT ·fromMont(SB), $96-8 MOVQ R13, 88(AX) RET -noAdx_2: +noAdx_3: MOVQ res+0(FP), AX MOVQ AX, (SP) CALL ·_fromMontGeneric(SB) diff --git a/field/asm/element_4w_amd64.h b/field/asm/element_4w_amd64.h new file mode 100644 index 000000000..a1fa17f7b --- /dev/null +++ b/field/asm/element_4w_amd64.h @@ -0,0 +1,1258 @@ +// Copyright 2020 ConsenSys Software Inc. 
+// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "textflag.h" +#include "funcdata.h" + +#define REDUCE(ra0, ra1, ra2, ra3, rb0, rb1, rb2, rb3) \ + MOVQ ra0, rb0; \ + SUBQ q<>(SB), ra0; \ + MOVQ ra1, rb1; \ + SBBQ q<>+8(SB), ra1; \ + MOVQ ra2, rb2; \ + SBBQ q<>+16(SB), ra2; \ + MOVQ ra3, rb3; \ + SBBQ q<>+24(SB), ra3; \ + CMOVQCS rb0, ra0; \ + CMOVQCS rb1, ra1; \ + CMOVQCS rb2, ra2; \ + CMOVQCS rb3, ra3; \ + +TEXT ·reduce(SB), NOSPLIT, $0-8 + MOVQ res+0(FP), AX + MOVQ 0(AX), DX + MOVQ 8(AX), CX + MOVQ 16(AX), BX + MOVQ 24(AX), SI + + // reduce element(DX,CX,BX,SI) using temp registers (DI,R8,R9,R10) + REDUCE(DX,CX,BX,SI,DI,R8,R9,R10) + + MOVQ DX, 0(AX) + MOVQ CX, 8(AX) + MOVQ BX, 16(AX) + MOVQ SI, 24(AX) + RET + +// MulBy3(x *Element) +TEXT ·MulBy3(SB), NOSPLIT, $0-8 + MOVQ x+0(FP), AX + MOVQ 0(AX), DX + MOVQ 8(AX), CX + MOVQ 16(AX), BX + MOVQ 24(AX), SI + ADDQ DX, DX + ADCQ CX, CX + ADCQ BX, BX + ADCQ SI, SI + + // reduce element(DX,CX,BX,SI) using temp registers (DI,R8,R9,R10) + REDUCE(DX,CX,BX,SI,DI,R8,R9,R10) + + ADDQ 0(AX), DX + ADCQ 8(AX), CX + ADCQ 16(AX), BX + ADCQ 24(AX), SI + + // reduce element(DX,CX,BX,SI) using temp registers (R11,R12,R13,R14) + REDUCE(DX,CX,BX,SI,R11,R12,R13,R14) + + MOVQ DX, 0(AX) + MOVQ CX, 8(AX) + MOVQ BX, 16(AX) + MOVQ SI, 24(AX) + RET + +// MulBy5(x *Element) +TEXT ·MulBy5(SB), NOSPLIT, $0-8 + MOVQ x+0(FP), AX + MOVQ 0(AX), DX + MOVQ 8(AX), CX + MOVQ 16(AX), BX + MOVQ 24(AX), SI + ADDQ DX, DX + ADCQ CX, CX + ADCQ BX, BX + ADCQ SI, SI + + // reduce element(DX,CX,BX,SI) using temp registers (DI,R8,R9,R10) + REDUCE(DX,CX,BX,SI,DI,R8,R9,R10) + + ADDQ DX, DX + ADCQ CX, CX + ADCQ BX, BX + ADCQ SI, SI + + // reduce element(DX,CX,BX,SI) using temp registers (R11,R12,R13,R14) + REDUCE(DX,CX,BX,SI,R11,R12,R13,R14) + + ADDQ 0(AX), DX + ADCQ 8(AX), CX + ADCQ 16(AX), BX + ADCQ 24(AX), SI + + // reduce element(DX,CX,BX,SI) using temp registers (R15,DI,R8,R9) + REDUCE(DX,CX,BX,SI,R15,DI,R8,R9) + + MOVQ DX, 0(AX) + MOVQ CX, 8(AX) + MOVQ BX, 16(AX) + MOVQ SI, 24(AX) + RET + +// MulBy13(x *Element) +TEXT ·MulBy13(SB), NOSPLIT, $0-8 + MOVQ x+0(FP), AX + MOVQ 0(AX), DX + MOVQ 8(AX), CX + MOVQ 16(AX), BX + MOVQ 24(AX), SI + ADDQ DX, DX + ADCQ CX, CX + ADCQ BX, BX + ADCQ SI, SI + + // reduce element(DX,CX,BX,SI) using temp registers (DI,R8,R9,R10) + REDUCE(DX,CX,BX,SI,DI,R8,R9,R10) + + ADDQ DX, DX + ADCQ CX, CX + ADCQ BX, BX + ADCQ SI, SI + + // reduce element(DX,CX,BX,SI) using temp registers (R11,R12,R13,R14) + REDUCE(DX,CX,BX,SI,R11,R12,R13,R14) + + MOVQ DX, R11 + MOVQ CX, R12 + MOVQ BX, R13 + MOVQ SI, R14 + ADDQ DX, DX + ADCQ CX, CX + ADCQ BX, BX + ADCQ SI, SI + + // reduce element(DX,CX,BX,SI) using temp registers (DI,R8,R9,R10) + REDUCE(DX,CX,BX,SI,DI,R8,R9,R10) + + ADDQ R11, DX + ADCQ R12, CX + ADCQ R13, BX + ADCQ R14, SI + + // reduce element(DX,CX,BX,SI) using temp registers (DI,R8,R9,R10) + REDUCE(DX,CX,BX,SI,DI,R8,R9,R10) + + ADDQ 0(AX), DX + ADCQ 8(AX), CX + ADCQ 16(AX), BX + ADCQ 24(AX), SI + + // reduce element(DX,CX,BX,SI) using temp registers 
(DI,R8,R9,R10) + REDUCE(DX,CX,BX,SI,DI,R8,R9,R10) + + MOVQ DX, 0(AX) + MOVQ CX, 8(AX) + MOVQ BX, 16(AX) + MOVQ SI, 24(AX) + RET + +// Butterfly(a, b *Element) sets a = a + b; b = a - b +TEXT ·Butterfly(SB), NOSPLIT, $0-16 + MOVQ a+0(FP), AX + MOVQ 0(AX), CX + MOVQ 8(AX), BX + MOVQ 16(AX), SI + MOVQ 24(AX), DI + MOVQ CX, R8 + MOVQ BX, R9 + MOVQ SI, R10 + MOVQ DI, R11 + XORQ AX, AX + MOVQ b+8(FP), DX + ADDQ 0(DX), CX + ADCQ 8(DX), BX + ADCQ 16(DX), SI + ADCQ 24(DX), DI + SUBQ 0(DX), R8 + SBBQ 8(DX), R9 + SBBQ 16(DX), R10 + SBBQ 24(DX), R11 + MOVQ q0, R12 + MOVQ q1, R13 + MOVQ q2, R14 + MOVQ q3, R15 + CMOVQCC AX, R12 + CMOVQCC AX, R13 + CMOVQCC AX, R14 + CMOVQCC AX, R15 + ADDQ R12, R8 + ADCQ R13, R9 + ADCQ R14, R10 + ADCQ R15, R11 + MOVQ R8, 0(DX) + MOVQ R9, 8(DX) + MOVQ R10, 16(DX) + MOVQ R11, 24(DX) + + // reduce element(CX,BX,SI,DI) using temp registers (R8,R9,R10,R11) + REDUCE(CX,BX,SI,DI,R8,R9,R10,R11) + + MOVQ a+0(FP), AX + MOVQ CX, 0(AX) + MOVQ BX, 8(AX) + MOVQ SI, 16(AX) + MOVQ DI, 24(AX) + RET + +// addVec(res, a, b *Element, n uint64) res[0...n] = a[0...n] + b[0...n] +TEXT ·addVec(SB), NOSPLIT, $0-32 + MOVQ res+0(FP), CX + MOVQ a+8(FP), AX + MOVQ b+16(FP), DX + MOVQ n+24(FP), BX + +loop_1: + TESTQ BX, BX + JEQ done_2 // n == 0, we are done + + // a[0] -> SI + // a[1] -> DI + // a[2] -> R8 + // a[3] -> R9 + MOVQ 0(AX), SI + MOVQ 8(AX), DI + MOVQ 16(AX), R8 + MOVQ 24(AX), R9 + ADDQ 0(DX), SI + ADCQ 8(DX), DI + ADCQ 16(DX), R8 + ADCQ 24(DX), R9 + + // reduce element(SI,DI,R8,R9) using temp registers (R10,R11,R12,R13) + REDUCE(SI,DI,R8,R9,R10,R11,R12,R13) + + MOVQ SI, 0(CX) + MOVQ DI, 8(CX) + MOVQ R8, 16(CX) + MOVQ R9, 24(CX) + + // increment pointers to visit next element + ADDQ $32, AX + ADDQ $32, DX + ADDQ $32, CX + DECQ BX // decrement n + JMP loop_1 + +done_2: + RET + +// subVec(res, a, b *Element, n uint64) res[0...n] = a[0...n] - b[0...n] +TEXT ·subVec(SB), NOSPLIT, $0-32 + MOVQ res+0(FP), CX + MOVQ a+8(FP), AX + MOVQ b+16(FP), DX + MOVQ n+24(FP), BX + XORQ SI, SI + +loop_3: + TESTQ BX, BX + JEQ done_4 // n == 0, we are done + + // a[0] -> DI + // a[1] -> R8 + // a[2] -> R9 + // a[3] -> R10 + MOVQ 0(AX), DI + MOVQ 8(AX), R8 + MOVQ 16(AX), R9 + MOVQ 24(AX), R10 + SUBQ 0(DX), DI + SBBQ 8(DX), R8 + SBBQ 16(DX), R9 + SBBQ 24(DX), R10 + + // reduce (a-b) mod q + // q[0] -> R11 + // q[1] -> R12 + // q[2] -> R13 + // q[3] -> R14 + MOVQ q0, R11 + MOVQ q1, R12 + MOVQ q2, R13 + MOVQ q3, R14 + CMOVQCC SI, R11 + CMOVQCC SI, R12 + CMOVQCC SI, R13 + CMOVQCC SI, R14 + + // add registers (q or 0) to a, and set to result + ADDQ R11, DI + ADCQ R12, R8 + ADCQ R13, R9 + ADCQ R14, R10 + MOVQ DI, 0(CX) + MOVQ R8, 8(CX) + MOVQ R9, 16(CX) + MOVQ R10, 24(CX) + + // increment pointers to visit next element + ADDQ $32, AX + ADDQ $32, DX + ADDQ $32, CX + DECQ BX // decrement n + JMP loop_3 + +done_4: + RET + +// scalarMulVec(res, a, b *Element, n uint64) res[0...n] = a[0...n] * b +TEXT ·scalarMulVec(SB), $56-32 + CMPB ·supportAdx(SB), $1 + JNE noAdx_5 + MOVQ a+8(FP), R11 + MOVQ b+16(FP), R10 + MOVQ n+24(FP), R12 + + // scalar[0] -> SI + // scalar[1] -> DI + // scalar[2] -> R8 + // scalar[3] -> R9 + MOVQ 0(R10), SI + MOVQ 8(R10), DI + MOVQ 16(R10), R8 + MOVQ 24(R10), R9 + MOVQ res+0(FP), R10 + +loop_6: + TESTQ R12, R12 + JEQ done_7 // n == 0, we are done + + // TODO @gbotrel this is generated from the same macro as the unit mul, we should refactor this in a single asm function + // A -> BP + // t[0] -> R14 + // t[1] -> R15 + // t[2] -> CX + // t[3] -> BX + // clear the flags + XORQ AX, AX + MOVQ 0(R11), DX + 
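+	// the block below is one fully unrolled 4-limb Montgomery multiplication (same macro as ·mul further
+	// down in this file), with the multiplicand limbs pinned to the scalar held in SI, DI, R8, R9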
+ // (A,t[0]) := x[0]*y[0] + A + MULXQ SI, R14, R15 + + // (A,t[1]) := x[1]*y[0] + A + MULXQ DI, AX, CX + ADOXQ AX, R15 + + // (A,t[2]) := x[2]*y[0] + A + MULXQ R8, AX, BX + ADOXQ AX, CX + + // (A,t[3]) := x[3]*y[0] + A + MULXQ R9, AX, BP + ADOXQ AX, BX + + // A += carries from ADCXQ and ADOXQ + MOVQ $0, AX + ADOXQ AX, BP + + // m := t[0]*q'[0] mod W + MOVQ qInv0<>(SB), DX + IMULQ R14, DX + + // clear the flags + XORQ AX, AX + + // C,_ := t[0] + m*q[0] + MULXQ q<>+0(SB), AX, R13 + ADCXQ R14, AX + MOVQ R13, R14 + + // (C,t[0]) := t[1] + m*q[1] + C + ADCXQ R15, R14 + MULXQ q<>+8(SB), AX, R15 + ADOXQ AX, R14 + + // (C,t[1]) := t[2] + m*q[2] + C + ADCXQ CX, R15 + MULXQ q<>+16(SB), AX, CX + ADOXQ AX, R15 + + // (C,t[2]) := t[3] + m*q[3] + C + ADCXQ BX, CX + MULXQ q<>+24(SB), AX, BX + ADOXQ AX, CX + + // t[3] = C + A + MOVQ $0, AX + ADCXQ AX, BX + ADOXQ BP, BX + + // clear the flags + XORQ AX, AX + MOVQ 8(R11), DX + + // (A,t[0]) := t[0] + x[0]*y[1] + A + MULXQ SI, AX, BP + ADOXQ AX, R14 + + // (A,t[1]) := t[1] + x[1]*y[1] + A + ADCXQ BP, R15 + MULXQ DI, AX, BP + ADOXQ AX, R15 + + // (A,t[2]) := t[2] + x[2]*y[1] + A + ADCXQ BP, CX + MULXQ R8, AX, BP + ADOXQ AX, CX + + // (A,t[3]) := t[3] + x[3]*y[1] + A + ADCXQ BP, BX + MULXQ R9, AX, BP + ADOXQ AX, BX + + // A += carries from ADCXQ and ADOXQ + MOVQ $0, AX + ADCXQ AX, BP + ADOXQ AX, BP + + // m := t[0]*q'[0] mod W + MOVQ qInv0<>(SB), DX + IMULQ R14, DX + + // clear the flags + XORQ AX, AX + + // C,_ := t[0] + m*q[0] + MULXQ q<>+0(SB), AX, R13 + ADCXQ R14, AX + MOVQ R13, R14 + + // (C,t[0]) := t[1] + m*q[1] + C + ADCXQ R15, R14 + MULXQ q<>+8(SB), AX, R15 + ADOXQ AX, R14 + + // (C,t[1]) := t[2] + m*q[2] + C + ADCXQ CX, R15 + MULXQ q<>+16(SB), AX, CX + ADOXQ AX, R15 + + // (C,t[2]) := t[3] + m*q[3] + C + ADCXQ BX, CX + MULXQ q<>+24(SB), AX, BX + ADOXQ AX, CX + + // t[3] = C + A + MOVQ $0, AX + ADCXQ AX, BX + ADOXQ BP, BX + + // clear the flags + XORQ AX, AX + MOVQ 16(R11), DX + + // (A,t[0]) := t[0] + x[0]*y[2] + A + MULXQ SI, AX, BP + ADOXQ AX, R14 + + // (A,t[1]) := t[1] + x[1]*y[2] + A + ADCXQ BP, R15 + MULXQ DI, AX, BP + ADOXQ AX, R15 + + // (A,t[2]) := t[2] + x[2]*y[2] + A + ADCXQ BP, CX + MULXQ R8, AX, BP + ADOXQ AX, CX + + // (A,t[3]) := t[3] + x[3]*y[2] + A + ADCXQ BP, BX + MULXQ R9, AX, BP + ADOXQ AX, BX + + // A += carries from ADCXQ and ADOXQ + MOVQ $0, AX + ADCXQ AX, BP + ADOXQ AX, BP + + // m := t[0]*q'[0] mod W + MOVQ qInv0<>(SB), DX + IMULQ R14, DX + + // clear the flags + XORQ AX, AX + + // C,_ := t[0] + m*q[0] + MULXQ q<>+0(SB), AX, R13 + ADCXQ R14, AX + MOVQ R13, R14 + + // (C,t[0]) := t[1] + m*q[1] + C + ADCXQ R15, R14 + MULXQ q<>+8(SB), AX, R15 + ADOXQ AX, R14 + + // (C,t[1]) := t[2] + m*q[2] + C + ADCXQ CX, R15 + MULXQ q<>+16(SB), AX, CX + ADOXQ AX, R15 + + // (C,t[2]) := t[3] + m*q[3] + C + ADCXQ BX, CX + MULXQ q<>+24(SB), AX, BX + ADOXQ AX, CX + + // t[3] = C + A + MOVQ $0, AX + ADCXQ AX, BX + ADOXQ BP, BX + + // clear the flags + XORQ AX, AX + MOVQ 24(R11), DX + + // (A,t[0]) := t[0] + x[0]*y[3] + A + MULXQ SI, AX, BP + ADOXQ AX, R14 + + // (A,t[1]) := t[1] + x[1]*y[3] + A + ADCXQ BP, R15 + MULXQ DI, AX, BP + ADOXQ AX, R15 + + // (A,t[2]) := t[2] + x[2]*y[3] + A + ADCXQ BP, CX + MULXQ R8, AX, BP + ADOXQ AX, CX + + // (A,t[3]) := t[3] + x[3]*y[3] + A + ADCXQ BP, BX + MULXQ R9, AX, BP + ADOXQ AX, BX + + // A += carries from ADCXQ and ADOXQ + MOVQ $0, AX + ADCXQ AX, BP + ADOXQ AX, BP + + // m := t[0]*q'[0] mod W + MOVQ qInv0<>(SB), DX + IMULQ R14, DX + + // clear the flags + XORQ AX, AX + + // C,_ := t[0] + m*q[0] + MULXQ 
q<>+0(SB), AX, R13 + ADCXQ R14, AX + MOVQ R13, R14 + + // (C,t[0]) := t[1] + m*q[1] + C + ADCXQ R15, R14 + MULXQ q<>+8(SB), AX, R15 + ADOXQ AX, R14 + + // (C,t[1]) := t[2] + m*q[2] + C + ADCXQ CX, R15 + MULXQ q<>+16(SB), AX, CX + ADOXQ AX, R15 + + // (C,t[2]) := t[3] + m*q[3] + C + ADCXQ BX, CX + MULXQ q<>+24(SB), AX, BX + ADOXQ AX, CX + + // t[3] = C + A + MOVQ $0, AX + ADCXQ AX, BX + ADOXQ BP, BX + + // reduce t mod q + // reduce element(R14,R15,CX,BX) using temp registers (R13,AX,DX,s0-8(SP)) + REDUCE(R14,R15,CX,BX,R13,AX,DX,s0-8(SP)) + + MOVQ R14, 0(R10) + MOVQ R15, 8(R10) + MOVQ CX, 16(R10) + MOVQ BX, 24(R10) + + // increment pointers to visit next element + ADDQ $32, R11 + ADDQ $32, R10 + DECQ R12 // decrement n + JMP loop_6 + +done_7: + RET + +noAdx_5: + MOVQ n+24(FP), DX + MOVQ res+0(FP), AX + MOVQ AX, (SP) + MOVQ DX, 8(SP) + MOVQ DX, 16(SP) + MOVQ a+8(FP), AX + MOVQ AX, 24(SP) + MOVQ DX, 32(SP) + MOVQ DX, 40(SP) + MOVQ b+16(FP), AX + MOVQ AX, 48(SP) + CALL ·scalarMulVecGeneric(SB) + RET + +// sumVec(res, a *Element, n uint64) res = sum(a[0...n]) +TEXT ·sumVec(SB), NOSPLIT, $0-24 + + // Derived from https://github.com/a16z/vectorized-fields + // The idea is to use Z registers to accumulate the sum of elements, 8 by 8 + // first, we handle the case where n % 8 != 0 + // then, we loop over the elements 8 by 8 and accumulate the sum in the Z registers + // finally, we reduce the sum and store it in res + // + // when we move an element of a into a Z register, we use VPMOVZXDQ + // let's note w0...w3 the 4 64bits words of ai: w0 = ai[0], w1 = ai[1], w2 = ai[2], w3 = ai[3] + // VPMOVZXDQ(ai, Z0) will result in + // Z0= [hi(w3), lo(w3), hi(w2), lo(w2), hi(w1), lo(w1), hi(w0), lo(w0)] + // with hi(wi) the high 32 bits of wi and lo(wi) the low 32 bits of wi + // we can safely add 2^32+1 times Z registers constructed this way without overflow + // since each of this lo/hi bits are moved into a "64bits" slot + // N = 2^64-1 / 2^32-1 = 2^32+1 + // + // we then propagate the carry using ADOXQ and ADCXQ + // r0 = w0l + lo(woh) + // r1 = carry + hi(woh) + w1l + lo(w1h) + // r2 = carry + hi(w1h) + w2l + lo(w2h) + // r3 = carry + hi(w2h) + w3l + lo(w3h) + // r4 = carry + hi(w3h) + // we then reduce the sum using a single-word Barrett reduction + // we pick mu = 2^288 / q; which correspond to 4.5 words max. + // meaning we must guarantee that r4 fits in 32bits. 
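+	// (the reduction folds r4 and the high 32 bits of r3 into a single 64-bit word via SHRQ $32, R12, AX
+	// before multiplying by mu; that folding only captures all of r4 while r4 < 2^32)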
+ // To do so, we reduce N to 2^32-1 (since r4 receives 2 carries max) + + MOVQ a+8(FP), R14 + MOVQ n+16(FP), R15 + + // initialize accumulators Z0, Z1, Z2, Z3, Z4, Z5, Z6, Z7 + VXORPS Z0, Z0, Z0 + VMOVDQA64 Z0, Z1 + VMOVDQA64 Z0, Z2 + VMOVDQA64 Z0, Z3 + VMOVDQA64 Z0, Z4 + VMOVDQA64 Z0, Z5 + VMOVDQA64 Z0, Z6 + VMOVDQA64 Z0, Z7 + + // n % 8 -> CX + // n / 8 -> R15 + MOVQ R15, CX + ANDQ $7, CX + SHRQ $3, R15 + +loop_single_10: + TESTQ CX, CX + JEQ loop8by8_8 // n % 8 == 0, we are going to loop over 8 by 8 + VPMOVZXDQ 0(R14), Z8 + VPADDQ Z8, Z0, Z0 + ADDQ $32, R14 + DECQ CX // decrement nMod8 + JMP loop_single_10 + +loop8by8_8: + TESTQ R15, R15 + JEQ accumulate_11 // n == 0, we are going to accumulate + VPMOVZXDQ 0*32(R14), Z8 + VPMOVZXDQ 1*32(R14), Z9 + VPMOVZXDQ 2*32(R14), Z10 + VPMOVZXDQ 3*32(R14), Z11 + VPMOVZXDQ 4*32(R14), Z12 + VPMOVZXDQ 5*32(R14), Z13 + VPMOVZXDQ 6*32(R14), Z14 + VPMOVZXDQ 7*32(R14), Z15 + PREFETCHT0 256(R14) + VPADDQ Z8, Z0, Z0 + VPADDQ Z9, Z1, Z1 + VPADDQ Z10, Z2, Z2 + VPADDQ Z11, Z3, Z3 + VPADDQ Z12, Z4, Z4 + VPADDQ Z13, Z5, Z5 + VPADDQ Z14, Z6, Z6 + VPADDQ Z15, Z7, Z7 + + // increment pointers to visit next 8 elements + ADDQ $256, R14 + DECQ R15 // decrement n + JMP loop8by8_8 + +accumulate_11: + // accumulate the 8 Z registers into Z0 + VPADDQ Z7, Z6, Z6 + VPADDQ Z6, Z5, Z5 + VPADDQ Z5, Z4, Z4 + VPADDQ Z4, Z3, Z3 + VPADDQ Z3, Z2, Z2 + VPADDQ Z2, Z1, Z1 + VPADDQ Z1, Z0, Z0 + + // carry propagation + // lo(w0) -> BX + // hi(w0) -> SI + // lo(w1) -> DI + // hi(w1) -> R8 + // lo(w2) -> R9 + // hi(w2) -> R10 + // lo(w3) -> R11 + // hi(w3) -> R12 + VMOVQ X0, BX + VALIGNQ $1, Z0, Z0, Z0 + VMOVQ X0, SI + VALIGNQ $1, Z0, Z0, Z0 + VMOVQ X0, DI + VALIGNQ $1, Z0, Z0, Z0 + VMOVQ X0, R8 + VALIGNQ $1, Z0, Z0, Z0 + VMOVQ X0, R9 + VALIGNQ $1, Z0, Z0, Z0 + VMOVQ X0, R10 + VALIGNQ $1, Z0, Z0, Z0 + VMOVQ X0, R11 + VALIGNQ $1, Z0, Z0, Z0 + VMOVQ X0, R12 + + // lo(hi(wo)) -> R13 + // lo(hi(w1)) -> CX + // lo(hi(w2)) -> R15 + // lo(hi(w3)) -> R14 +#define SPLIT_LO_HI(lo, hi) \ + MOVQ hi, lo; \ + ANDQ $0xffffffff, lo; \ + SHLQ $32, lo; \ + SHRQ $32, hi; \ + + SPLIT_LO_HI(R13, SI) + SPLIT_LO_HI(CX, R8) + SPLIT_LO_HI(R15, R10) + SPLIT_LO_HI(R14, R12) + + // r0 = w0l + lo(woh) + // r1 = carry + hi(woh) + w1l + lo(w1h) + // r2 = carry + hi(w1h) + w2l + lo(w2h) + // r3 = carry + hi(w2h) + w3l + lo(w3h) + // r4 = carry + hi(w3h) + + XORQ AX, AX // clear the flags + ADOXQ R13, BX + ADOXQ CX, DI + ADCXQ SI, DI + ADOXQ R15, R9 + ADCXQ R8, R9 + ADOXQ R14, R11 + ADCXQ R10, R11 + ADOXQ AX, R12 + ADCXQ AX, R12 + + // r[0] -> BX + // r[1] -> DI + // r[2] -> R9 + // r[3] -> R11 + // r[4] -> R12 + // reduce using single-word Barrett + // mu=2^288 / q -> SI + MOVQ mu<>(SB), SI + MOVQ R11, AX + SHRQ $32, R12, AX + MULQ SI // high bits of res stored in DX + MULXQ q<>+0(SB), AX, SI + SUBQ AX, BX + SBBQ SI, DI + MULXQ q<>+16(SB), AX, SI + SBBQ AX, R9 + SBBQ SI, R11 + SBBQ $0, R12 + MULXQ q<>+8(SB), AX, SI + SUBQ AX, DI + SBBQ SI, R9 + MULXQ q<>+24(SB), AX, SI + SBBQ AX, R11 + SBBQ SI, R12 + MOVQ BX, R8 + MOVQ DI, R10 + MOVQ R9, R13 + MOVQ R11, CX + SUBQ q<>+0(SB), BX + SBBQ q<>+8(SB), DI + SBBQ q<>+16(SB), R9 + SBBQ q<>+24(SB), R11 + SBBQ $0, R12 + JCS modReduced_12 + MOVQ BX, R8 + MOVQ DI, R10 + MOVQ R9, R13 + MOVQ R11, CX + SUBQ q<>+0(SB), BX + SBBQ q<>+8(SB), DI + SBBQ q<>+16(SB), R9 + SBBQ q<>+24(SB), R11 + SBBQ $0, R12 + JCS modReduced_12 + MOVQ BX, R8 + MOVQ DI, R10 + MOVQ R9, R13 + MOVQ R11, CX + +modReduced_12: + MOVQ res+0(FP), SI + MOVQ R8, 0(SI) + MOVQ R10, 8(SI) + MOVQ R13, 16(SI) + MOVQ CX, 
24(SI) + +done_9: + RET + +// mul(res, x, y *Element) +TEXT ·mul(SB), $24-24 + + // Algorithm 2 of "Faster Montgomery Multiplication and Multi-Scalar-Multiplication for SNARKS" + // by Y. El Housni and G. Botrel https://doi.org/10.46586/tches.v2023.i3.504-521 + + NO_LOCAL_POINTERS + CMPB ·supportAdx(SB), $1 + JNE noAdx_13 + MOVQ x+8(FP), SI + + // x[0] -> DI + // x[1] -> R8 + // x[2] -> R9 + // x[3] -> R10 + MOVQ 0(SI), DI + MOVQ 8(SI), R8 + MOVQ 16(SI), R9 + MOVQ 24(SI), R10 + MOVQ y+16(FP), R11 + + // A -> BP + // t[0] -> R14 + // t[1] -> R13 + // t[2] -> CX + // t[3] -> BX + // clear the flags + XORQ AX, AX + MOVQ 0(R11), DX + + // (A,t[0]) := x[0]*y[0] + A + MULXQ DI, R14, R13 + + // (A,t[1]) := x[1]*y[0] + A + MULXQ R8, AX, CX + ADOXQ AX, R13 + + // (A,t[2]) := x[2]*y[0] + A + MULXQ R9, AX, BX + ADOXQ AX, CX + + // (A,t[3]) := x[3]*y[0] + A + MULXQ R10, AX, BP + ADOXQ AX, BX + + // A += carries from ADCXQ and ADOXQ + MOVQ $0, AX + ADOXQ AX, BP + + // m := t[0]*q'[0] mod W + MOVQ qInv0<>(SB), DX + IMULQ R14, DX + + // clear the flags + XORQ AX, AX + + // C,_ := t[0] + m*q[0] + MULXQ q<>+0(SB), AX, R12 + ADCXQ R14, AX + MOVQ R12, R14 + + // (C,t[0]) := t[1] + m*q[1] + C + ADCXQ R13, R14 + MULXQ q<>+8(SB), AX, R13 + ADOXQ AX, R14 + + // (C,t[1]) := t[2] + m*q[2] + C + ADCXQ CX, R13 + MULXQ q<>+16(SB), AX, CX + ADOXQ AX, R13 + + // (C,t[2]) := t[3] + m*q[3] + C + ADCXQ BX, CX + MULXQ q<>+24(SB), AX, BX + ADOXQ AX, CX + + // t[3] = C + A + MOVQ $0, AX + ADCXQ AX, BX + ADOXQ BP, BX + + // clear the flags + XORQ AX, AX + MOVQ 8(R11), DX + + // (A,t[0]) := t[0] + x[0]*y[1] + A + MULXQ DI, AX, BP + ADOXQ AX, R14 + + // (A,t[1]) := t[1] + x[1]*y[1] + A + ADCXQ BP, R13 + MULXQ R8, AX, BP + ADOXQ AX, R13 + + // (A,t[2]) := t[2] + x[2]*y[1] + A + ADCXQ BP, CX + MULXQ R9, AX, BP + ADOXQ AX, CX + + // (A,t[3]) := t[3] + x[3]*y[1] + A + ADCXQ BP, BX + MULXQ R10, AX, BP + ADOXQ AX, BX + + // A += carries from ADCXQ and ADOXQ + MOVQ $0, AX + ADCXQ AX, BP + ADOXQ AX, BP + + // m := t[0]*q'[0] mod W + MOVQ qInv0<>(SB), DX + IMULQ R14, DX + + // clear the flags + XORQ AX, AX + + // C,_ := t[0] + m*q[0] + MULXQ q<>+0(SB), AX, R12 + ADCXQ R14, AX + MOVQ R12, R14 + + // (C,t[0]) := t[1] + m*q[1] + C + ADCXQ R13, R14 + MULXQ q<>+8(SB), AX, R13 + ADOXQ AX, R14 + + // (C,t[1]) := t[2] + m*q[2] + C + ADCXQ CX, R13 + MULXQ q<>+16(SB), AX, CX + ADOXQ AX, R13 + + // (C,t[2]) := t[3] + m*q[3] + C + ADCXQ BX, CX + MULXQ q<>+24(SB), AX, BX + ADOXQ AX, CX + + // t[3] = C + A + MOVQ $0, AX + ADCXQ AX, BX + ADOXQ BP, BX + + // clear the flags + XORQ AX, AX + MOVQ 16(R11), DX + + // (A,t[0]) := t[0] + x[0]*y[2] + A + MULXQ DI, AX, BP + ADOXQ AX, R14 + + // (A,t[1]) := t[1] + x[1]*y[2] + A + ADCXQ BP, R13 + MULXQ R8, AX, BP + ADOXQ AX, R13 + + // (A,t[2]) := t[2] + x[2]*y[2] + A + ADCXQ BP, CX + MULXQ R9, AX, BP + ADOXQ AX, CX + + // (A,t[3]) := t[3] + x[3]*y[2] + A + ADCXQ BP, BX + MULXQ R10, AX, BP + ADOXQ AX, BX + + // A += carries from ADCXQ and ADOXQ + MOVQ $0, AX + ADCXQ AX, BP + ADOXQ AX, BP + + // m := t[0]*q'[0] mod W + MOVQ qInv0<>(SB), DX + IMULQ R14, DX + + // clear the flags + XORQ AX, AX + + // C,_ := t[0] + m*q[0] + MULXQ q<>+0(SB), AX, R12 + ADCXQ R14, AX + MOVQ R12, R14 + + // (C,t[0]) := t[1] + m*q[1] + C + ADCXQ R13, R14 + MULXQ q<>+8(SB), AX, R13 + ADOXQ AX, R14 + + // (C,t[1]) := t[2] + m*q[2] + C + ADCXQ CX, R13 + MULXQ q<>+16(SB), AX, CX + ADOXQ AX, R13 + + // (C,t[2]) := t[3] + m*q[3] + C + ADCXQ BX, CX + MULXQ q<>+24(SB), AX, BX + ADOXQ AX, CX + + // t[3] = C + A + MOVQ $0, AX + ADCXQ AX, BX + 
ADOXQ BP, BX + + // clear the flags + XORQ AX, AX + MOVQ 24(R11), DX + + // (A,t[0]) := t[0] + x[0]*y[3] + A + MULXQ DI, AX, BP + ADOXQ AX, R14 + + // (A,t[1]) := t[1] + x[1]*y[3] + A + ADCXQ BP, R13 + MULXQ R8, AX, BP + ADOXQ AX, R13 + + // (A,t[2]) := t[2] + x[2]*y[3] + A + ADCXQ BP, CX + MULXQ R9, AX, BP + ADOXQ AX, CX + + // (A,t[3]) := t[3] + x[3]*y[3] + A + ADCXQ BP, BX + MULXQ R10, AX, BP + ADOXQ AX, BX + + // A += carries from ADCXQ and ADOXQ + MOVQ $0, AX + ADCXQ AX, BP + ADOXQ AX, BP + + // m := t[0]*q'[0] mod W + MOVQ qInv0<>(SB), DX + IMULQ R14, DX + + // clear the flags + XORQ AX, AX + + // C,_ := t[0] + m*q[0] + MULXQ q<>+0(SB), AX, R12 + ADCXQ R14, AX + MOVQ R12, R14 + + // (C,t[0]) := t[1] + m*q[1] + C + ADCXQ R13, R14 + MULXQ q<>+8(SB), AX, R13 + ADOXQ AX, R14 + + // (C,t[1]) := t[2] + m*q[2] + C + ADCXQ CX, R13 + MULXQ q<>+16(SB), AX, CX + ADOXQ AX, R13 + + // (C,t[2]) := t[3] + m*q[3] + C + ADCXQ BX, CX + MULXQ q<>+24(SB), AX, BX + ADOXQ AX, CX + + // t[3] = C + A + MOVQ $0, AX + ADCXQ AX, BX + ADOXQ BP, BX + + // reduce element(R14,R13,CX,BX) using temp registers (SI,R12,R11,DI) + REDUCE(R14,R13,CX,BX,SI,R12,R11,DI) + + MOVQ res+0(FP), AX + MOVQ R14, 0(AX) + MOVQ R13, 8(AX) + MOVQ CX, 16(AX) + MOVQ BX, 24(AX) + RET + +noAdx_13: + MOVQ res+0(FP), AX + MOVQ AX, (SP) + MOVQ x+8(FP), AX + MOVQ AX, 8(SP) + MOVQ y+16(FP), AX + MOVQ AX, 16(SP) + CALL ·_mulGeneric(SB) + RET + +TEXT ·fromMont(SB), $8-8 + NO_LOCAL_POINTERS + + // Algorithm 2 of "Faster Montgomery Multiplication and Multi-Scalar-Multiplication for SNARKS" + // by Y. El Housni and G. Botrel https://doi.org/10.46586/tches.v2023.i3.504-521 + // when y = 1 we have: + // for i=0 to N-1 + // t[i] = x[i] + // for i=0 to N-1 + // m := t[0]*q'[0] mod W + // C,_ := t[0] + m*q[0] + // for j=1 to N-1 + // (C,t[j-1]) := t[j] + m*q[j] + C + // t[N-1] = C + CMPB ·supportAdx(SB), $1 + JNE noAdx_14 + MOVQ res+0(FP), DX + MOVQ 0(DX), R14 + MOVQ 8(DX), R13 + MOVQ 16(DX), CX + MOVQ 24(DX), BX + XORQ DX, DX + + // m := t[0]*q'[0] mod W + MOVQ qInv0<>(SB), DX + IMULQ R14, DX + XORQ AX, AX + + // C,_ := t[0] + m*q[0] + MULXQ q<>+0(SB), AX, BP + ADCXQ R14, AX + MOVQ BP, R14 + + // (C,t[0]) := t[1] + m*q[1] + C + ADCXQ R13, R14 + MULXQ q<>+8(SB), AX, R13 + ADOXQ AX, R14 + + // (C,t[1]) := t[2] + m*q[2] + C + ADCXQ CX, R13 + MULXQ q<>+16(SB), AX, CX + ADOXQ AX, R13 + + // (C,t[2]) := t[3] + m*q[3] + C + ADCXQ BX, CX + MULXQ q<>+24(SB), AX, BX + ADOXQ AX, CX + MOVQ $0, AX + ADCXQ AX, BX + ADOXQ AX, BX + XORQ DX, DX + + // m := t[0]*q'[0] mod W + MOVQ qInv0<>(SB), DX + IMULQ R14, DX + XORQ AX, AX + + // C,_ := t[0] + m*q[0] + MULXQ q<>+0(SB), AX, BP + ADCXQ R14, AX + MOVQ BP, R14 + + // (C,t[0]) := t[1] + m*q[1] + C + ADCXQ R13, R14 + MULXQ q<>+8(SB), AX, R13 + ADOXQ AX, R14 + + // (C,t[1]) := t[2] + m*q[2] + C + ADCXQ CX, R13 + MULXQ q<>+16(SB), AX, CX + ADOXQ AX, R13 + + // (C,t[2]) := t[3] + m*q[3] + C + ADCXQ BX, CX + MULXQ q<>+24(SB), AX, BX + ADOXQ AX, CX + MOVQ $0, AX + ADCXQ AX, BX + ADOXQ AX, BX + XORQ DX, DX + + // m := t[0]*q'[0] mod W + MOVQ qInv0<>(SB), DX + IMULQ R14, DX + XORQ AX, AX + + // C,_ := t[0] + m*q[0] + MULXQ q<>+0(SB), AX, BP + ADCXQ R14, AX + MOVQ BP, R14 + + // (C,t[0]) := t[1] + m*q[1] + C + ADCXQ R13, R14 + MULXQ q<>+8(SB), AX, R13 + ADOXQ AX, R14 + + // (C,t[1]) := t[2] + m*q[2] + C + ADCXQ CX, R13 + MULXQ q<>+16(SB), AX, CX + ADOXQ AX, R13 + + // (C,t[2]) := t[3] + m*q[3] + C + ADCXQ BX, CX + MULXQ q<>+24(SB), AX, BX + ADOXQ AX, CX + MOVQ $0, AX + ADCXQ AX, BX + ADOXQ AX, BX + XORQ DX, DX + + // m := 
t[0]*q'[0] mod W + MOVQ qInv0<>(SB), DX + IMULQ R14, DX + XORQ AX, AX + + // C,_ := t[0] + m*q[0] + MULXQ q<>+0(SB), AX, BP + ADCXQ R14, AX + MOVQ BP, R14 + + // (C,t[0]) := t[1] + m*q[1] + C + ADCXQ R13, R14 + MULXQ q<>+8(SB), AX, R13 + ADOXQ AX, R14 + + // (C,t[1]) := t[2] + m*q[2] + C + ADCXQ CX, R13 + MULXQ q<>+16(SB), AX, CX + ADOXQ AX, R13 + + // (C,t[2]) := t[3] + m*q[3] + C + ADCXQ BX, CX + MULXQ q<>+24(SB), AX, BX + ADOXQ AX, CX + MOVQ $0, AX + ADCXQ AX, BX + ADOXQ AX, BX + + // reduce element(R14,R13,CX,BX) using temp registers (SI,DI,R8,R9) + REDUCE(R14,R13,CX,BX,SI,DI,R8,R9) + + MOVQ res+0(FP), AX + MOVQ R14, 0(AX) + MOVQ R13, 8(AX) + MOVQ CX, 16(AX) + MOVQ BX, 24(AX) + RET + +noAdx_14: + MOVQ res+0(FP), AX + MOVQ AX, (SP) + CALL ·_fromMontGeneric(SB) + RET diff --git a/ecc/bw6-633/fr/element_mul_amd64.s b/field/asm/element_5w_amd64.h similarity index 70% rename from ecc/bw6-633/fr/element_mul_amd64.s rename to field/asm/element_5w_amd64.h index 92bba4f58..bfc4c176b 100644 --- a/ecc/bw6-633/fr/element_mul_amd64.s +++ b/field/asm/element_5w_amd64.h @@ -1,5 +1,3 @@ -// +build !purego - // Copyright 2020 ConsenSys Software Inc. // // Licensed under the Apache License, Version 2.0 (the "License"); @@ -17,18 +15,6 @@ #include "textflag.h" #include "funcdata.h" -// modulus q -DATA q<>+0(SB)/8, $0x6fe802ff40300001 -DATA q<>+8(SB)/8, $0x421ee5da52bde502 -DATA q<>+16(SB)/8, $0xdec1d01aa27a1ae0 -DATA q<>+24(SB)/8, $0xd3f7498be97c5eaf -DATA q<>+32(SB)/8, $0x04c23a02b586d650 -GLOBL q<>(SB), (RODATA+NOPTR), $40 - -// qInv0 q'[0] -DATA qInv0<>(SB)/8, $0x702ff9ff402fffff -GLOBL qInv0<>(SB), (RODATA+NOPTR), $8 - #define REDUCE(ra0, ra1, ra2, ra3, ra4, rb0, rb1, rb2, rb3, rb4) \ MOVQ ra0, rb0; \ SUBQ q<>(SB), ra0; \ @@ -46,20 +32,236 @@ GLOBL qInv0<>(SB), (RODATA+NOPTR), $8 CMOVQCS rb3, ra3; \ CMOVQCS rb4, ra4; \ +TEXT ·reduce(SB), NOSPLIT, $0-8 + MOVQ res+0(FP), AX + MOVQ 0(AX), DX + MOVQ 8(AX), CX + MOVQ 16(AX), BX + MOVQ 24(AX), SI + MOVQ 32(AX), DI + + // reduce element(DX,CX,BX,SI,DI) using temp registers (R8,R9,R10,R11,R12) + REDUCE(DX,CX,BX,SI,DI,R8,R9,R10,R11,R12) + + MOVQ DX, 0(AX) + MOVQ CX, 8(AX) + MOVQ BX, 16(AX) + MOVQ SI, 24(AX) + MOVQ DI, 32(AX) + RET + +// MulBy3(x *Element) +TEXT ·MulBy3(SB), NOSPLIT, $0-8 + MOVQ x+0(FP), AX + MOVQ 0(AX), DX + MOVQ 8(AX), CX + MOVQ 16(AX), BX + MOVQ 24(AX), SI + MOVQ 32(AX), DI + ADDQ DX, DX + ADCQ CX, CX + ADCQ BX, BX + ADCQ SI, SI + ADCQ DI, DI + + // reduce element(DX,CX,BX,SI,DI) using temp registers (R8,R9,R10,R11,R12) + REDUCE(DX,CX,BX,SI,DI,R8,R9,R10,R11,R12) + + ADDQ 0(AX), DX + ADCQ 8(AX), CX + ADCQ 16(AX), BX + ADCQ 24(AX), SI + ADCQ 32(AX), DI + + // reduce element(DX,CX,BX,SI,DI) using temp registers (R13,R14,R15,R8,R9) + REDUCE(DX,CX,BX,SI,DI,R13,R14,R15,R8,R9) + + MOVQ DX, 0(AX) + MOVQ CX, 8(AX) + MOVQ BX, 16(AX) + MOVQ SI, 24(AX) + MOVQ DI, 32(AX) + RET + +// MulBy5(x *Element) +TEXT ·MulBy5(SB), NOSPLIT, $0-8 + MOVQ x+0(FP), AX + MOVQ 0(AX), DX + MOVQ 8(AX), CX + MOVQ 16(AX), BX + MOVQ 24(AX), SI + MOVQ 32(AX), DI + ADDQ DX, DX + ADCQ CX, CX + ADCQ BX, BX + ADCQ SI, SI + ADCQ DI, DI + + // reduce element(DX,CX,BX,SI,DI) using temp registers (R8,R9,R10,R11,R12) + REDUCE(DX,CX,BX,SI,DI,R8,R9,R10,R11,R12) + + ADDQ DX, DX + ADCQ CX, CX + ADCQ BX, BX + ADCQ SI, SI + ADCQ DI, DI + + // reduce element(DX,CX,BX,SI,DI) using temp registers (R13,R14,R15,R8,R9) + REDUCE(DX,CX,BX,SI,DI,R13,R14,R15,R8,R9) + + ADDQ 0(AX), DX + ADCQ 8(AX), CX + ADCQ 16(AX), BX + ADCQ 24(AX), SI + ADCQ 32(AX), DI + + // reduce element(DX,CX,BX,SI,DI) using temp 
registers (R10,R11,R12,R13,R14) + REDUCE(DX,CX,BX,SI,DI,R10,R11,R12,R13,R14) + + MOVQ DX, 0(AX) + MOVQ CX, 8(AX) + MOVQ BX, 16(AX) + MOVQ SI, 24(AX) + MOVQ DI, 32(AX) + RET + +// MulBy13(x *Element) +TEXT ·MulBy13(SB), $16-8 + MOVQ x+0(FP), AX + MOVQ 0(AX), DX + MOVQ 8(AX), CX + MOVQ 16(AX), BX + MOVQ 24(AX), SI + MOVQ 32(AX), DI + ADDQ DX, DX + ADCQ CX, CX + ADCQ BX, BX + ADCQ SI, SI + ADCQ DI, DI + + // reduce element(DX,CX,BX,SI,DI) using temp registers (R8,R9,R10,R11,R12) + REDUCE(DX,CX,BX,SI,DI,R8,R9,R10,R11,R12) + + ADDQ DX, DX + ADCQ CX, CX + ADCQ BX, BX + ADCQ SI, SI + ADCQ DI, DI + + // reduce element(DX,CX,BX,SI,DI) using temp registers (R13,R14,R15,s0-8(SP),s1-16(SP)) + REDUCE(DX,CX,BX,SI,DI,R13,R14,R15,s0-8(SP),s1-16(SP)) + + MOVQ DX, R13 + MOVQ CX, R14 + MOVQ BX, R15 + MOVQ SI, s0-8(SP) + MOVQ DI, s1-16(SP) + ADDQ DX, DX + ADCQ CX, CX + ADCQ BX, BX + ADCQ SI, SI + ADCQ DI, DI + + // reduce element(DX,CX,BX,SI,DI) using temp registers (R8,R9,R10,R11,R12) + REDUCE(DX,CX,BX,SI,DI,R8,R9,R10,R11,R12) + + ADDQ R13, DX + ADCQ R14, CX + ADCQ R15, BX + ADCQ s0-8(SP), SI + ADCQ s1-16(SP), DI + + // reduce element(DX,CX,BX,SI,DI) using temp registers (R8,R9,R10,R11,R12) + REDUCE(DX,CX,BX,SI,DI,R8,R9,R10,R11,R12) + + ADDQ 0(AX), DX + ADCQ 8(AX), CX + ADCQ 16(AX), BX + ADCQ 24(AX), SI + ADCQ 32(AX), DI + + // reduce element(DX,CX,BX,SI,DI) using temp registers (R8,R9,R10,R11,R12) + REDUCE(DX,CX,BX,SI,DI,R8,R9,R10,R11,R12) + + MOVQ DX, 0(AX) + MOVQ CX, 8(AX) + MOVQ BX, 16(AX) + MOVQ SI, 24(AX) + MOVQ DI, 32(AX) + RET + +// Butterfly(a, b *Element) sets a = a + b; b = a - b +TEXT ·Butterfly(SB), $24-16 + MOVQ a+0(FP), AX + MOVQ 0(AX), CX + MOVQ 8(AX), BX + MOVQ 16(AX), SI + MOVQ 24(AX), DI + MOVQ 32(AX), R8 + MOVQ CX, R9 + MOVQ BX, R10 + MOVQ SI, R11 + MOVQ DI, R12 + MOVQ R8, R13 + XORQ AX, AX + MOVQ b+8(FP), DX + ADDQ 0(DX), CX + ADCQ 8(DX), BX + ADCQ 16(DX), SI + ADCQ 24(DX), DI + ADCQ 32(DX), R8 + SUBQ 0(DX), R9 + SBBQ 8(DX), R10 + SBBQ 16(DX), R11 + SBBQ 24(DX), R12 + SBBQ 32(DX), R13 + MOVQ CX, R14 + MOVQ BX, R15 + MOVQ SI, s0-8(SP) + MOVQ DI, s1-16(SP) + MOVQ R8, s2-24(SP) + MOVQ q0, CX + MOVQ q1, BX + MOVQ q2, SI + MOVQ q3, DI + MOVQ q4, R8 + CMOVQCC AX, CX + CMOVQCC AX, BX + CMOVQCC AX, SI + CMOVQCC AX, DI + CMOVQCC AX, R8 + ADDQ CX, R9 + ADCQ BX, R10 + ADCQ SI, R11 + ADCQ DI, R12 + ADCQ R8, R13 + MOVQ R14, CX + MOVQ R15, BX + MOVQ s0-8(SP), SI + MOVQ s1-16(SP), DI + MOVQ s2-24(SP), R8 + MOVQ R9, 0(DX) + MOVQ R10, 8(DX) + MOVQ R11, 16(DX) + MOVQ R12, 24(DX) + MOVQ R13, 32(DX) + + // reduce element(CX,BX,SI,DI,R8) using temp registers (R9,R10,R11,R12,R13) + REDUCE(CX,BX,SI,DI,R8,R9,R10,R11,R12,R13) + + MOVQ a+0(FP), AX + MOVQ CX, 0(AX) + MOVQ BX, 8(AX) + MOVQ SI, 16(AX) + MOVQ DI, 24(AX) + MOVQ R8, 32(AX) + RET + // mul(res, x, y *Element) TEXT ·mul(SB), $24-24 - // the algorithm is described in the Element.Mul declaration (.go) - // however, to benefit from the ADCX and ADOX carry chains - // we split the inner loops in 2: - // for i=0 to N-1 - // for j=0 to N-1 - // (A,t[j]) := t[j] + x[j]*y[i] + A - // m := t[0]*q'[0] mod W - // C,_ := t[0] + m*q[0] - // for j=1 to N-1 - // (C,t[j-1]) := t[j] + m*q[j] + C - // t[N-1] = C + A + // Algorithm 2 of "Faster Montgomery Multiplication and Multi-Scalar-Multiplication for SNARKS" + // by Y. El Housni and G. 
Botrel https://doi.org/10.46586/tches.v2023.i3.504-521 NO_LOCAL_POINTERS CMPB ·supportAdx(SB), $1 @@ -448,8 +650,8 @@ TEXT ·mul(SB), $24-24 TEXT ·fromMont(SB), $8-8 NO_LOCAL_POINTERS - // the algorithm is described here - // https://hackmd.io/@gnark/modular_multiplication + // Algorithm 2 of "Faster Montgomery Multiplication and Multi-Scalar-Multiplication for SNARKS" + // by Y. El Housni and G. Botrel https://doi.org/10.46586/tches.v2023.i3.504-521 // when y = 1 we have: // for i=0 to N-1 // t[i] = x[i] diff --git a/ecc/bls12-377/fp/element_mul_amd64.s b/field/asm/element_6w_amd64.h similarity index 72% rename from ecc/bls12-377/fp/element_mul_amd64.s rename to field/asm/element_6w_amd64.h index 1e19c4d3f..ee3ebf51c 100644 --- a/ecc/bls12-377/fp/element_mul_amd64.s +++ b/field/asm/element_6w_amd64.h @@ -1,5 +1,3 @@ -// +build !purego - // Copyright 2020 ConsenSys Software Inc. // // Licensed under the Apache License, Version 2.0 (the "License"); @@ -17,19 +15,6 @@ #include "textflag.h" #include "funcdata.h" -// modulus q -DATA q<>+0(SB)/8, $0x8508c00000000001 -DATA q<>+8(SB)/8, $0x170b5d4430000000 -DATA q<>+16(SB)/8, $0x1ef3622fba094800 -DATA q<>+24(SB)/8, $0x1a22d9f300f5138f -DATA q<>+32(SB)/8, $0xc63b05c06ca1493b -DATA q<>+40(SB)/8, $0x01ae3a4617c510ea -GLOBL q<>(SB), (RODATA+NOPTR), $48 - -// qInv0 q'[0] -DATA qInv0<>(SB)/8, $0x8508bfffffffffff -GLOBL qInv0<>(SB), (RODATA+NOPTR), $8 - #define REDUCE(ra0, ra1, ra2, ra3, ra4, ra5, rb0, rb1, rb2, rb3, rb4, rb5) \ MOVQ ra0, rb0; \ SUBQ q<>(SB), ra0; \ @@ -50,20 +35,266 @@ GLOBL qInv0<>(SB), (RODATA+NOPTR), $8 CMOVQCS rb4, ra4; \ CMOVQCS rb5, ra5; \ +TEXT ·reduce(SB), NOSPLIT, $0-8 + MOVQ res+0(FP), AX + MOVQ 0(AX), DX + MOVQ 8(AX), CX + MOVQ 16(AX), BX + MOVQ 24(AX), SI + MOVQ 32(AX), DI + MOVQ 40(AX), R8 + + // reduce element(DX,CX,BX,SI,DI,R8) using temp registers (R9,R10,R11,R12,R13,R14) + REDUCE(DX,CX,BX,SI,DI,R8,R9,R10,R11,R12,R13,R14) + + MOVQ DX, 0(AX) + MOVQ CX, 8(AX) + MOVQ BX, 16(AX) + MOVQ SI, 24(AX) + MOVQ DI, 32(AX) + MOVQ R8, 40(AX) + RET + +// MulBy3(x *Element) +TEXT ·MulBy3(SB), NOSPLIT, $0-8 + MOVQ x+0(FP), AX + MOVQ 0(AX), DX + MOVQ 8(AX), CX + MOVQ 16(AX), BX + MOVQ 24(AX), SI + MOVQ 32(AX), DI + MOVQ 40(AX), R8 + ADDQ DX, DX + ADCQ CX, CX + ADCQ BX, BX + ADCQ SI, SI + ADCQ DI, DI + ADCQ R8, R8 + + // reduce element(DX,CX,BX,SI,DI,R8) using temp registers (R9,R10,R11,R12,R13,R14) + REDUCE(DX,CX,BX,SI,DI,R8,R9,R10,R11,R12,R13,R14) + + ADDQ 0(AX), DX + ADCQ 8(AX), CX + ADCQ 16(AX), BX + ADCQ 24(AX), SI + ADCQ 32(AX), DI + ADCQ 40(AX), R8 + + // reduce element(DX,CX,BX,SI,DI,R8) using temp registers (R15,R9,R10,R11,R12,R13) + REDUCE(DX,CX,BX,SI,DI,R8,R15,R9,R10,R11,R12,R13) + + MOVQ DX, 0(AX) + MOVQ CX, 8(AX) + MOVQ BX, 16(AX) + MOVQ SI, 24(AX) + MOVQ DI, 32(AX) + MOVQ R8, 40(AX) + RET + +// MulBy5(x *Element) +TEXT ·MulBy5(SB), NOSPLIT, $0-8 + MOVQ x+0(FP), AX + MOVQ 0(AX), DX + MOVQ 8(AX), CX + MOVQ 16(AX), BX + MOVQ 24(AX), SI + MOVQ 32(AX), DI + MOVQ 40(AX), R8 + ADDQ DX, DX + ADCQ CX, CX + ADCQ BX, BX + ADCQ SI, SI + ADCQ DI, DI + ADCQ R8, R8 + + // reduce element(DX,CX,BX,SI,DI,R8) using temp registers (R9,R10,R11,R12,R13,R14) + REDUCE(DX,CX,BX,SI,DI,R8,R9,R10,R11,R12,R13,R14) + + ADDQ DX, DX + ADCQ CX, CX + ADCQ BX, BX + ADCQ SI, SI + ADCQ DI, DI + ADCQ R8, R8 + + // reduce element(DX,CX,BX,SI,DI,R8) using temp registers (R15,R9,R10,R11,R12,R13) + REDUCE(DX,CX,BX,SI,DI,R8,R15,R9,R10,R11,R12,R13) + + ADDQ 0(AX), DX + ADCQ 8(AX), CX + ADCQ 16(AX), BX + ADCQ 24(AX), SI + ADCQ 32(AX), DI + ADCQ 40(AX), R8 + + // reduce 
element(DX,CX,BX,SI,DI,R8) using temp registers (R14,R15,R9,R10,R11,R12) + REDUCE(DX,CX,BX,SI,DI,R8,R14,R15,R9,R10,R11,R12) + + MOVQ DX, 0(AX) + MOVQ CX, 8(AX) + MOVQ BX, 16(AX) + MOVQ SI, 24(AX) + MOVQ DI, 32(AX) + MOVQ R8, 40(AX) + RET + +// MulBy13(x *Element) +TEXT ·MulBy13(SB), $40-8 + MOVQ x+0(FP), AX + MOVQ 0(AX), DX + MOVQ 8(AX), CX + MOVQ 16(AX), BX + MOVQ 24(AX), SI + MOVQ 32(AX), DI + MOVQ 40(AX), R8 + ADDQ DX, DX + ADCQ CX, CX + ADCQ BX, BX + ADCQ SI, SI + ADCQ DI, DI + ADCQ R8, R8 + + // reduce element(DX,CX,BX,SI,DI,R8) using temp registers (R9,R10,R11,R12,R13,R14) + REDUCE(DX,CX,BX,SI,DI,R8,R9,R10,R11,R12,R13,R14) + + ADDQ DX, DX + ADCQ CX, CX + ADCQ BX, BX + ADCQ SI, SI + ADCQ DI, DI + ADCQ R8, R8 + + // reduce element(DX,CX,BX,SI,DI,R8) using temp registers (R15,s0-8(SP),s1-16(SP),s2-24(SP),s3-32(SP),s4-40(SP)) + REDUCE(DX,CX,BX,SI,DI,R8,R15,s0-8(SP),s1-16(SP),s2-24(SP),s3-32(SP),s4-40(SP)) + + MOVQ DX, R15 + MOVQ CX, s0-8(SP) + MOVQ BX, s1-16(SP) + MOVQ SI, s2-24(SP) + MOVQ DI, s3-32(SP) + MOVQ R8, s4-40(SP) + ADDQ DX, DX + ADCQ CX, CX + ADCQ BX, BX + ADCQ SI, SI + ADCQ DI, DI + ADCQ R8, R8 + + // reduce element(DX,CX,BX,SI,DI,R8) using temp registers (R9,R10,R11,R12,R13,R14) + REDUCE(DX,CX,BX,SI,DI,R8,R9,R10,R11,R12,R13,R14) + + ADDQ R15, DX + ADCQ s0-8(SP), CX + ADCQ s1-16(SP), BX + ADCQ s2-24(SP), SI + ADCQ s3-32(SP), DI + ADCQ s4-40(SP), R8 + + // reduce element(DX,CX,BX,SI,DI,R8) using temp registers (R9,R10,R11,R12,R13,R14) + REDUCE(DX,CX,BX,SI,DI,R8,R9,R10,R11,R12,R13,R14) + + ADDQ 0(AX), DX + ADCQ 8(AX), CX + ADCQ 16(AX), BX + ADCQ 24(AX), SI + ADCQ 32(AX), DI + ADCQ 40(AX), R8 + + // reduce element(DX,CX,BX,SI,DI,R8) using temp registers (R9,R10,R11,R12,R13,R14) + REDUCE(DX,CX,BX,SI,DI,R8,R9,R10,R11,R12,R13,R14) + + MOVQ DX, 0(AX) + MOVQ CX, 8(AX) + MOVQ BX, 16(AX) + MOVQ SI, 24(AX) + MOVQ DI, 32(AX) + MOVQ R8, 40(AX) + RET + +// Butterfly(a, b *Element) sets a = a + b; b = a - b +TEXT ·Butterfly(SB), $48-16 + MOVQ a+0(FP), AX + MOVQ 0(AX), CX + MOVQ 8(AX), BX + MOVQ 16(AX), SI + MOVQ 24(AX), DI + MOVQ 32(AX), R8 + MOVQ 40(AX), R9 + MOVQ CX, R10 + MOVQ BX, R11 + MOVQ SI, R12 + MOVQ DI, R13 + MOVQ R8, R14 + MOVQ R9, R15 + XORQ AX, AX + MOVQ b+8(FP), DX + ADDQ 0(DX), CX + ADCQ 8(DX), BX + ADCQ 16(DX), SI + ADCQ 24(DX), DI + ADCQ 32(DX), R8 + ADCQ 40(DX), R9 + SUBQ 0(DX), R10 + SBBQ 8(DX), R11 + SBBQ 16(DX), R12 + SBBQ 24(DX), R13 + SBBQ 32(DX), R14 + SBBQ 40(DX), R15 + MOVQ CX, s0-8(SP) + MOVQ BX, s1-16(SP) + MOVQ SI, s2-24(SP) + MOVQ DI, s3-32(SP) + MOVQ R8, s4-40(SP) + MOVQ R9, s5-48(SP) + MOVQ q0, CX + MOVQ q1, BX + MOVQ q2, SI + MOVQ q3, DI + MOVQ q4, R8 + MOVQ q5, R9 + CMOVQCC AX, CX + CMOVQCC AX, BX + CMOVQCC AX, SI + CMOVQCC AX, DI + CMOVQCC AX, R8 + CMOVQCC AX, R9 + ADDQ CX, R10 + ADCQ BX, R11 + ADCQ SI, R12 + ADCQ DI, R13 + ADCQ R8, R14 + ADCQ R9, R15 + MOVQ s0-8(SP), CX + MOVQ s1-16(SP), BX + MOVQ s2-24(SP), SI + MOVQ s3-32(SP), DI + MOVQ s4-40(SP), R8 + MOVQ s5-48(SP), R9 + MOVQ R10, 0(DX) + MOVQ R11, 8(DX) + MOVQ R12, 16(DX) + MOVQ R13, 24(DX) + MOVQ R14, 32(DX) + MOVQ R15, 40(DX) + + // reduce element(CX,BX,SI,DI,R8,R9) using temp registers (R10,R11,R12,R13,R14,R15) + REDUCE(CX,BX,SI,DI,R8,R9,R10,R11,R12,R13,R14,R15) + + MOVQ a+0(FP), AX + MOVQ CX, 0(AX) + MOVQ BX, 8(AX) + MOVQ SI, 16(AX) + MOVQ DI, 24(AX) + MOVQ R8, 32(AX) + MOVQ R9, 40(AX) + RET + // mul(res, x, y *Element) TEXT ·mul(SB), $24-24 - // the algorithm is described in the Element.Mul declaration (.go) - // however, to benefit from the ADCX and ADOX carry chains - // we split the inner 
loops in 2: - // for i=0 to N-1 - // for j=0 to N-1 - // (A,t[j]) := t[j] + x[j]*y[i] + A - // m := t[0]*q'[0] mod W - // C,_ := t[0] + m*q[0] - // for j=1 to N-1 - // (C,t[j-1]) := t[j] + m*q[j] + C - // t[N-1] = C + A + // Algorithm 2 of "Faster Montgomery Multiplication and Multi-Scalar-Multiplication for SNARKS" + // by Y. El Housni and G. Botrel https://doi.org/10.46586/tches.v2023.i3.504-521 NO_LOCAL_POINTERS CMPB ·supportAdx(SB), $1 @@ -583,8 +814,8 @@ TEXT ·mul(SB), $24-24 TEXT ·fromMont(SB), $8-8 NO_LOCAL_POINTERS - // the algorithm is described here - // https://hackmd.io/@gnark/modular_multiplication + // Algorithm 2 of "Faster Montgomery Multiplication and Multi-Scalar-Multiplication for SNARKS" + // by Y. El Housni and G. Botrel https://doi.org/10.46586/tches.v2023.i3.504-521 // when y = 1 we have: // for i=0 to N-1 // t[i] = x[i] diff --git a/field/generator/asm/amd64/asm_macros.go b/field/generator/asm/amd64/asm_macros.go index 676ff0fb7..2397ee089 100644 --- a/field/generator/asm/amd64/asm_macros.go +++ b/field/generator/asm/amd64/asm_macros.go @@ -20,6 +20,7 @@ import ( "text/template" "github.com/consensys/bavard/amd64" + "github.com/consensys/gnark-crypto/field/generator/config" ) // LabelRegisters write comment with friendler name to registers @@ -62,7 +63,7 @@ func (f *FFAmd64) ReduceElement(t, scratch []amd64.Register) { } // TODO @gbotrel: figure out if interleaving MOVQ and SUBQ or CMOVQ and MOVQ instructions makes sense -const tmplDefines = ` +const tmplDefinesDeprecated = ` // modulus q {{- range $i, $w := .Q}} @@ -74,7 +75,7 @@ GLOBL q<>(SB), (RODATA+NOPTR), ${{mul 8 $.NbWords}} DATA qInv0<>(SB)/8, {{$qinv0 := index .QInverse 0}}{{imm $qinv0}} GLOBL qInv0<>(SB), (RODATA+NOPTR), $8 -{{- if eq .NbWords 4}} +{{- if .ASMVector }} // Mu DATA mu<>(SB)/8, {{imm .Mu}} GLOBL mu<>(SB), (RODATA+NOPTR), $8 @@ -91,13 +92,60 @@ GLOBL mu<>(SB), (RODATA+NOPTR), $8 {{- range $i := .NbWordsIndexesFull}} CMOVQCS rb{{$i}}, ra{{$i}}; \ {{- end}} +` + +const tmplFieldDefines = ` + +// modulus q +{{- range $i, $w := .Q}} +DATA q<>+{{mul $i 8}}(SB)/8, {{imm $w}} +{{- end}} +GLOBL q<>(SB), (RODATA+NOPTR), ${{mul 8 $.NbWords}} + +// qInv0 q'[0] +DATA qInv0<>(SB)/8, {{$qinv0 := index .QInverse 0}}{{imm $qinv0}} +GLOBL qInv0<>(SB), (RODATA+NOPTR), $8 + +{{- if .ASMVector }} +// Mu +DATA mu<>(SB)/8, {{imm .Mu}} +GLOBL mu<>(SB), (RODATA+NOPTR), $8 +{{- end}} +` + +const tmplReduceDefine = ` +#define REDUCE( {{- range $i := .NbWordsIndexesFull}}ra{{$i}},{{- end}} + {{- range $i := .NbWordsIndexesFull}}rb{{$i}}{{- if ne $.NbWordsLastIndex $i}},{{- end}}{{- end}}) \ + MOVQ ra0, rb0; \ + SUBQ q<>(SB), ra0; \ + {{- range $i := .NbWordsIndexesNoZero}} + MOVQ ra{{$i}}, rb{{$i}}; \ + SBBQ q<>+{{mul $i 8}}(SB), ra{{$i}}; \ + {{- end}} + {{- range $i := .NbWordsIndexesFull}} + CMOVQCS rb{{$i}}, ra{{$i}}; \ + {{- end}} ` -func (f *FFAmd64) GenerateDefines() { +func (f *FFAmd64) GenerateFieldDefines(F *config.FieldConfig) { + tmpl := template.Must(template.New(""). + Funcs(helpers()). + Parse(tmplFieldDefines)) + + // execute template + var buf bytes.Buffer + if err := tmpl.Execute(&buf, F); err != nil { + panic(err) + } + + f.WriteLn(buf.String()) +} + +func (f *FFAmd64) GenerateReduceDefine() { tmpl := template.Must(template.New(""). Funcs(helpers()). 
- Parse(tmplDefines)) + Parse(tmplReduceDefine)) // execute template var buf bytes.Buffer @@ -108,6 +156,20 @@ func (f *FFAmd64) GenerateDefines() { f.WriteLn(buf.String()) } +func (f *FFAmd64) GenerateDefinesDeprecated(F *config.FieldConfig) { + tmpl := template.Must(template.New(""). + Funcs(helpers()). + Parse(tmplDefinesDeprecated)) + + // execute template + var buf bytes.Buffer + if err := tmpl.Execute(&buf, F); err != nil { + panic(err) + } + + f.WriteLn(buf.String()) +} + func (f *FFAmd64) Mov(i1, i2 interface{}, offsets ...int) { var o1, o2 int if len(offsets) >= 1 { diff --git a/field/generator/asm/amd64/build.go b/field/generator/asm/amd64/build.go index 9a9a65cf7..0f1c8925e 100644 --- a/field/generator/asm/amd64/build.go +++ b/field/generator/asm/amd64/build.go @@ -18,6 +18,7 @@ package amd64 import ( "fmt" "io" + "path/filepath" "strings" "github.com/consensys/bavard" @@ -28,15 +29,37 @@ import ( const SmallModulus = 6 -func NewFFAmd64(w io.Writer, F *config.FieldConfig) *FFAmd64 { - return &FFAmd64{F, amd64.NewAmd64(w), 0, 0} +func NewFFAmd64(w io.Writer, nbWords int) *FFAmd64 { + F := &FFAmd64{ + amd64.NewAmd64(w), + 0, + 0, + nbWords, + nbWords - 1, + make([]int, nbWords), + make([]int, nbWords-1), + } + + // indexes (template helpers) + for i := 0; i < F.NbWords; i++ { + F.NbWordsIndexesFull[i] = i + if i > 0 { + F.NbWordsIndexesNoZero[i-1] = i + } + } + + return F } type FFAmd64 struct { - *config.FieldConfig + // *config.FieldConfig *amd64.Amd64 - nbElementsOnStack int - maxOnStack int + nbElementsOnStack int + maxOnStack int + NbWords int + NbWordsLastIndex int + NbWordsIndexesFull []int + NbWordsIndexesNoZero []int } func (f *FFAmd64) StackSize(maxNbRegistersNeeded, nbRegistersReserved, minStackSize int) int { @@ -139,17 +162,36 @@ func (f *FFAmd64) mu() string { return "mu<>(SB)" } -// Generate generates assembly code for the base field provided to goff +func GenerateFieldWrapper(w io.Writer, F *config.FieldConfig, asmDir string) error { + // for each field we generate the defines for the modulus and the montgomery constant + f := NewFFAmd64(w, F.NbWords) + + // we add the defines first, then the common asm, then the global variable section + // to enable correct compilations with #include in order. 
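As an illustrative aside (not part of the patch), the wrapper file that the loop and include written just below are expected to produce looks roughly like this for a 4-word field; the limb values and the include path are placeholders, since the real ones depend on the field being generated and on asmDir:

    #define q0 $0x...    // least-significant limb of the modulus
    #define q1 $0x...
    #define q2 $0x...
    #define q3 $0x...

    #include "element_4w_amd64.h"    // path resolved against asmDir

    // ... followed by the DATA/GLOBL declarations for q<>, qInv0<>
    // (and mu<> when ASMVector is set) emitted by GenerateFieldDefines

The ordering matters because the shared element_Nw_amd64.h headers consume the qN defines textually (for example in Butterfly), so the defines must precede the #include, whereas q<>, qInv0<> and mu<> are referenced by symbol and may be declared after it.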
+ f.WriteLn("") + for i := 0; i < F.NbWords; i++ { + f.WriteLn(fmt.Sprintf("#define q%d $%#016x", i, F.Q[i])) + } + + toInclude := fmt.Sprintf("element_%dw_amd64.h", F.NbWords) + f.WriteLn(fmt.Sprintf("\n#include \"%s\"\n", filepath.Join(asmDir, toInclude))) + + f.GenerateFieldDefines(F) + + return nil +} + +// GenerateCommonASM generates assembly code for the base field provided to goff // see internal/templates/ops* -func Generate(w io.Writer, F *config.FieldConfig) error { - f := NewFFAmd64(w, F) +func GenerateCommonASM(w io.Writer, nbWords int) error { + f := NewFFAmd64(w, nbWords) f.WriteLn(bavard.Apache2Header("ConsenSys Software Inc.", 2020)) f.WriteLn("#include \"textflag.h\"") f.WriteLn("#include \"funcdata.h\"") f.WriteLn("") - f.GenerateDefines() + f.GenerateReduceDefine() // reduce f.generateReduce() @@ -170,18 +212,6 @@ func Generate(w io.Writer, F *config.FieldConfig) error { f.generateSumVec() } - return nil -} - -func GenerateMul(w io.Writer, F *config.FieldConfig) error { - f := NewFFAmd64(w, F) - f.WriteLn(bavard.Apache2Header("ConsenSys Software Inc.", 2020)) - - f.WriteLn("#include \"textflag.h\"") - f.WriteLn("#include \"funcdata.h\"") - f.WriteLn("") - f.GenerateDefines() - // mul f.generateMul(false) @@ -190,21 +220,3 @@ func GenerateMul(w io.Writer, F *config.FieldConfig) error { return nil } - -func GenerateMulADX(w io.Writer, F *config.FieldConfig) error { - f := NewFFAmd64(w, F) - f.WriteLn(bavard.Apache2Header("ConsenSys Software Inc.", 2020)) - - f.WriteLn("#include \"textflag.h\"") - f.WriteLn("#include \"funcdata.h\"") - f.WriteLn("") - f.GenerateDefines() - - // mul - f.generateMul(true) - - // from mont - f.generateFromMont(true) - - return nil -} diff --git a/field/generator/asm/amd64/element_butterfly.go b/field/generator/asm/amd64/element_butterfly.go index 2d996754a..948e16ea1 100644 --- a/field/generator/asm/amd64/element_butterfly.go +++ b/field/generator/asm/amd64/element_butterfly.go @@ -14,6 +14,8 @@ package amd64 +import "fmt" + // Butterfly sets // a = a + b // b = a - b @@ -56,7 +58,9 @@ func (f *FFAmd64) generateButterfly() { if f.NbWords >= 5 { // q is on the stack, can't use for CMOVQCC f.Mov(t0, q) // save t0 - f.Mov(f.Q, t0) + for i := 0; i < f.NbWords; i++ { + f.MOVQ(fmt.Sprintf("q%d", i), t0[i]) + } for i := 0; i < f.NbWords; i++ { f.CMOVQCC(a, t0[i]) } @@ -64,7 +68,9 @@ func (f *FFAmd64) generateButterfly() { f.Add(t0, t1) f.Mov(q, t0) // restore t0 } else { - f.Mov(f.Q, q) + for i := 0; i < f.NbWords; i++ { + f.MOVQ(fmt.Sprintf("q%d", i), q[i]) + } for i := 0; i < f.NbWords; i++ { f.CMOVQCC(a, q[i]) } @@ -110,10 +116,11 @@ func (f *FFAmd64) generateButterfly() { noReduce := f.NewLabel("noReduce") f.JCC(noReduce) q := r - f.MOVQ(f.Q[0], q) + f.MOVQ("q0", q) + f.ADDQ(q, t0[0]) for i := 1; i < f.NbWords; i++ { - f.MOVQ(f.Q[i], q) + f.MOVQ(fmt.Sprintf("q%d", i), q) f.ADCQ(q, t0[i]) } f.LABEL(noReduce) diff --git a/field/generator/asm/amd64/element_frommont.go b/field/generator/asm/amd64/element_frommont.go index 79b717dcc..de9d0e3c4 100644 --- a/field/generator/asm/amd64/element_frommont.go +++ b/field/generator/asm/amd64/element_frommont.go @@ -42,8 +42,8 @@ func (f *FFAmd64) generateFromMont(forceADX bool) { f.WriteLn("NO_LOCAL_POINTERS") } f.WriteLn(` - // the algorithm is described here - // https://hackmd.io/@gnark/modular_multiplication + // Algorithm 2 of "Faster Montgomery Multiplication and Multi-Scalar-Multiplication for SNARKS" + // by Y. El Housni and G. 
Botrel https://doi.org/10.46586/tches.v2023.i3.504-521 // when y = 1 we have: // for i=0 to N-1 // t[i] = x[i] diff --git a/field/generator/asm/amd64/element_mul.go b/field/generator/asm/amd64/element_mul.go index df63f1ad4..b5791bc30 100644 --- a/field/generator/asm/amd64/element_mul.go +++ b/field/generator/asm/amd64/element_mul.go @@ -152,19 +152,10 @@ func (f *FFAmd64) generateMul(forceADX bool) { registers := f.FnHeader("mul", stackSize, argSize, reserved...) defer f.AssertCleanStack(stackSize, minStackSize) - f.WriteLn(fmt.Sprintf(` - // the algorithm is described in the %s.Mul declaration (.go) - // however, to benefit from the ADCX and ADOX carry chains - // we split the inner loops in 2: - // for i=0 to N-1 - // for j=0 to N-1 - // (A,t[j]) := t[j] + x[j]*y[i] + A - // m := t[0]*q'[0] mod W - // C,_ := t[0] + m*q[0] - // for j=1 to N-1 - // (C,t[j-1]) := t[j] + m*q[j] + C - // t[N-1] = C + A - `, f.ElementName)) + f.WriteLn(` + // Algorithm 2 of "Faster Montgomery Multiplication and Multi-Scalar-Multiplication for SNARKS" + // by Y. El Housni and G. Botrel https://doi.org/10.46586/tches.v2023.i3.504-521 + `) if stackSize > 0 { f.WriteLn("NO_LOCAL_POINTERS") } diff --git a/field/generator/asm/amd64/element_vec.go b/field/generator/asm/amd64/element_vec.go index a4ea06144..e8a380abd 100644 --- a/field/generator/asm/amd64/element_vec.go +++ b/field/generator/asm/amd64/element_vec.go @@ -125,7 +125,9 @@ func (f *FFAmd64) generateSubVec() { // reduce a f.Comment("reduce (a-b) mod q") f.LabelRegisters("q", q...) - f.Mov(f.Q, q) + for i := 0; i < f.NbWords; i++ { + f.MOVQ(fmt.Sprintf("q%d", i), q[i]) + } for i := 0; i < f.NbWords; i++ { f.CMOVQCC(zero, q[i]) } diff --git a/field/generator/config/field_config.go b/field/generator/config/field_config.go index 5da7da0a3..360e32540 100644 --- a/field/generator/config/field_config.go +++ b/field/generator/config/field_config.go @@ -51,8 +51,9 @@ type FieldConfig struct { Q []uint64 QInverse []uint64 QMinusOneHalvedP []uint64 // ((q-1) / 2 ) + 1 - Mu uint64 // mu = 2^288 / q for barrett reduction + Mu uint64 // mu = 2^288 / q for 4.5 word barrett reduction ASM bool + ASMVector bool RSquare []uint64 One, Thirteen []uint64 LegendreExponent string // big.Int to base16 string @@ -118,16 +119,6 @@ func NewFieldConfig(packageName, elementName, modulus string, useAddChain bool) _qInv.Mod(_qInv, _r) F.QInverse = toUint64Slice(_qInv, F.NbWords) - // setting Mu 2^288 / q - if F.NbWords == 4 { - // TODO @gbotrel clean for all modulus. 
- _mu := big.NewInt(1) - _mu.Lsh(_mu, 288) - _mu.Div(_mu, &bModulus) - muSlice := toUint64Slice(_mu, F.NbWords) - F.Mu = muSlice[0] - } - // Pornin20 inversion correction factors k := 32 // Optimized for 64 bit machines, still works for 32 @@ -271,6 +262,16 @@ func NewFieldConfig(packageName, elementName, modulus string, useAddChain bool) // moduli that meet the condition F.NoCarry // asm code generation for moduli with more than 6 words can be optimized further F.ASM = F.NoCarry && F.NbWords <= 12 && F.NbWords > 1 + F.ASMVector = F.ASM && F.NbWords == 4 && F.NbBits > 225 + + // setting Mu 2^288 / q + if F.ASMVector { + _mu := big.NewInt(1) + _mu.Lsh(_mu, 288) + _mu.Div(_mu, &bModulus) + muSlice := toUint64Slice(_mu, F.NbWords) + F.Mu = muSlice[0] + } return F, nil } diff --git a/field/generator/generator.go b/field/generator/generator.go index e149ac9ae..28bc54528 100644 --- a/field/generator/generator.go +++ b/field/generator/generator.go @@ -22,7 +22,7 @@ import ( // // fp, _ = config.NewField("fp", "Element", fpModulus") // generator.GenerateFF(fp, filepath.Join(baseDir, "fp")) -func GenerateFF(F *config.FieldConfig, outputDir string) error { +func GenerateFF(F *config.FieldConfig, outputDir, asmDir string) error { // source file templates sourceFiles := []string{ element.Base, @@ -137,33 +137,7 @@ func GenerateFF(F *config.FieldConfig, outputDir string) error { _, _ = io.WriteString(f, "// +build !purego\n") - if err := amd64.Generate(f, F); err != nil { - _ = f.Close() - return err - } - _ = f.Close() - - // run asmfmt - // run go fmt on whole directory - cmd := exec.Command("asmfmt", "-w", pathSrc) - cmd.Stdout = os.Stdout - cmd.Stderr = os.Stderr - if err := cmd.Run(); err != nil { - return err - } - } - - { - pathSrc := filepath.Join(outputDir, eName+"_mul_amd64.s") - fmt.Println("generating", pathSrc) - f, err := os.Create(pathSrc) - if err != nil { - return err - } - - _, _ = io.WriteString(f, "// +build !purego\n") - - if err := amd64.GenerateMul(f, F); err != nil { + if err := amd64.GenerateFieldWrapper(f, F, asmDir); err != nil { _ = f.Close() return err } @@ -274,3 +248,30 @@ func shorten(input string) string { } return input } + +func GenerateCommonASM(nbWords int, asmDir string) error { + pathSrc := filepath.Join(asmDir, fmt.Sprintf("element_%dw_amd64.h", nbWords)) + + fmt.Println("generating", pathSrc) + f, err := os.Create(pathSrc) + if err != nil { + return err + } + + if err := amd64.GenerateCommonASM(f, nbWords); err != nil { + _ = f.Close() + return err + } + _ = f.Close() + + // run asmfmt + // run go fmt on whole directory + cmd := exec.Command("asmfmt", "-w", pathSrc) + cmd.Stdout = os.Stdout + cmd.Stderr = os.Stderr + if err := cmd.Run(); err != nil { + return err + } + + return nil +} diff --git a/field/generator/generator_test.go b/field/generator/generator_test.go index e107a0e85..aa0b5c364 100644 --- a/field/generator/generator_test.go +++ b/field/generator/generator_test.go @@ -21,27 +21,32 @@ import ( "os" "os/exec" "path/filepath" + "strings" "testing" field "github.com/consensys/gnark-crypto/field/generator/config" + "github.com/stretchr/testify/require" + "golang.org/x/sync/errgroup" ) // integration test will create modulus for various field sizes and run tests const rootDir = "integration_test" +const asmDir = "../asm" func TestIntegration(t *testing.T) { + assert := require.New(t) + os.RemoveAll(rootDir) err := os.MkdirAll(rootDir, 0700) defer os.RemoveAll(rootDir) - if err != nil { - t.Fatal(err) - } + assert.NoError(err) var bits []int for i := 64; 
i <= 448; i += 64 { bits = append(bits, i-3, i-2, i-1, i, i+1) } + bits = append(bits, 224, 225, 226) moduli := make(map[string]string) for _, i := range bits { @@ -78,27 +83,53 @@ func TestIntegration(t *testing.T) { // generate field childDir := filepath.Join(rootDir, elementName) fIntegration, err = field.NewFieldConfig("integration", elementName, modulus, false) - if err != nil { - t.Fatal(elementName, err) - } - if err = GenerateFF(fIntegration, childDir); err != nil { - t.Fatal(elementName, err) - } + assert.NoError(err) + assert.NoError(GenerateFF(fIntegration, childDir, "../../../asm")) } + assert.NoError(GenerateCommonASM(2, asmDir)) + assert.NoError(GenerateCommonASM(3, asmDir)) + assert.NoError(GenerateCommonASM(7, asmDir)) + assert.NoError(GenerateCommonASM(8, asmDir)) + // run go test wd, err := os.Getwd() if err != nil { t.Fatal(err) } - packageDir := filepath.Join(wd, rootDir) + string(filepath.Separator) + "..." - cmd := exec.Command("go", "test", packageDir) - if err := cmd.Run(); err != nil { - if exitErr, ok := err.(*exec.ExitError); ok { - t.Fatal(string(exitErr.Stderr)) - } else { - t.Fatal(err) + packageDir := filepath.Join(wd, rootDir) // + string(filepath.Separator) + "..." + + // list all subdirectories in package dir + var subDirs []string + err = filepath.Walk(packageDir, func(path string, info os.FileInfo, err error) error { + if info.IsDir() && path != packageDir { + subDirs = append(subDirs, path) } + return nil + }) + if err != nil { + t.Fatal(err) + } + + errGroup := errgroup.Group{} + + for _, subDir := range subDirs { + // run go test in parallel + errGroup.Go(func() error { + cmd := exec.Command("go", "test") + cmd.Dir = subDir + var stdouterr strings.Builder + cmd.Stdout = &stdouterr + cmd.Stderr = &stdouterr + if err := cmd.Run(); err != nil { + return fmt.Errorf("go test failed, output:\n%s\n%s", stdouterr.String(), err) + } + return nil + }) + } + + if err := errGroup.Wait(); err != nil { + t.Fatal(err) } } diff --git a/field/generator/internal/templates/element/asm.go b/field/generator/internal/templates/element/asm.go index 5a73aa53b..9ddfecc25 100644 --- a/field/generator/internal/templates/element/asm.go +++ b/field/generator/internal/templates/element/asm.go @@ -7,7 +7,7 @@ import "golang.org/x/sys/cpu" var ( supportAdx = cpu.X86.HasADX && cpu.X86.HasBMI2 _ = supportAdx - {{- if eq .NbWords 4}} + {{- if .ASMVector}} supportAvx512 = cpu.X86.HasAVX512 && cpu.X86.HasAVX512DQ _ = supportAvx512 {{- end}} @@ -23,7 +23,7 @@ const AsmNoAdx = ` var ( supportAdx = false _ = supportAdx - {{- if eq .NbWords 4}} + {{- if .ASMVector}} supportAvx512 = false _ = supportAvx512 {{- end}} diff --git a/field/generator/internal/templates/element/mul_cios.go b/field/generator/internal/templates/element/mul_cios.go index 070d07ec8..6fb554c3b 100644 --- a/field/generator/internal/templates/element/mul_cios.go +++ b/field/generator/internal/templates/element/mul_cios.go @@ -2,6 +2,9 @@ package element // MulCIOS text book CIOS works for all modulus. // +// Algorithm 2 of "Faster Montgomery Multiplication and Multi-Scalar-Multiplication for SNARKS" +// by Y. El Housni and G. Botrel https://doi.org/10.46586/tches.v2023.i3.504-521 +// // There are couple of variations to the multiplication (and squaring) algorithms. 
// // All versions are derived from the Montgomery CIOS algorithm: see @@ -126,49 +129,7 @@ const MulCIOS = ` const MulDoc = ` {{define "mul_doc noCarry"}} -// Implements CIOS multiplication -- section 2.3.2 of Tolga Acar's thesis -// https://www.microsoft.com/en-us/research/wp-content/uploads/1998/06/97Acar.pdf -// -// The algorithm: -// -// for i=0 to N-1 -// C := 0 -// for j=0 to N-1 -// (C,t[j]) := t[j] + x[j]*y[i] + C -// (t[N+1],t[N]) := t[N] + C -// -// C := 0 -// m := t[0]*q'[0] mod D -// (C,_) := t[0] + m*q[0] -// for j=1 to N-1 -// (C,t[j-1]) := t[j] + m*q[j] + C -// -// (C,t[N-1]) := t[N] + C -// t[N] := t[N+1] + C -// -// → N is the number of machine words needed to store the modulus q -// → D is the word size. For example, on a 64-bit architecture D is 2 64 -// → x[i], y[i], q[i] is the ith word of the numbers x,y,q -// → q'[0] is the lowest word of the number -q⁻¹ mod r. This quantity is pre-computed, as it does not depend on the inputs. -// → t is a temporary array of size N+2 -// → C, S are machine words. A pair (C,S) refers to (hi-bits, lo-bits) of a two-word number -{{- if .noCarry}} -// -// As described here https://hackmd.io/@gnark/modular_multiplication we can get rid of one carry chain and simplify: -// (also described in https://eprint.iacr.org/2022/1400.pdf annex) -// -// for i=0 to N-1 -// (A,t[0]) := t[0] + x[0]*y[i] -// m := t[0]*q'[0] mod W -// C,_ := t[0] + m*q[0] -// for j=1 to N-1 -// (A,t[j]) := t[j] + x[j]*y[i] + A -// (C,t[j-1]) := t[j] + m*q[j] + C -// -// t[N-1] = C + A -// -// This optimization saves 5N + 2 additions in the algorithm, and can be used whenever the highest bit -// of the modulus is zero (and not all of the remaining bits are set). -{{- end}} +// Algorithm 2 of "Faster Montgomery Multiplication and Multi-Scalar-Multiplication for SNARKS" +// by Y. El Housni and G. Botrel https://doi.org/10.46586/tches.v2023.i3.504-521 {{ end }} ` diff --git a/field/generator/internal/templates/element/mul_nocarry.go b/field/generator/internal/templates/element/mul_nocarry.go index 0ec89f7a8..14740fd4a 100644 --- a/field/generator/internal/templates/element/mul_nocarry.go +++ b/field/generator/internal/templates/element/mul_nocarry.go @@ -1,6 +1,8 @@ package element -// MulNoCarry see https://eprint.iacr.org/2022/1400.pdf annex for more info on the algorithm +// MulNoCarry +// Algorithm 2 of "Faster Montgomery Multiplication and Multi-Scalar-Multiplication for SNARKS" +// by Y. El Housni and G. Botrel https://doi.org/10.46586/tches.v2023.i3.504-521 // Note that these templates are optimized for arm64 target, since x86 benefits from assembly impl. const MulNoCarry = ` {{ define "mul_nocarry" }} diff --git a/field/generator/internal/templates/element/ops_asm.go b/field/generator/internal/templates/element/ops_asm.go index c068d7125..7892b7bc2 100644 --- a/field/generator/internal/templates/element/ops_asm.go +++ b/field/generator/internal/templates/element/ops_asm.go @@ -29,7 +29,7 @@ func reduce(res *{{.ElementName}}) //go:noescape func Butterfly(a, b *{{.ElementName}}) -{{- if eq .NbWords 4}} +{{- if .ASMVector}} // Add adds two vectors element-wise and stores the result in self. // It panics if the vectors don't have the same length. 
func (vector *Vector) Add(a, b Vector) { diff --git a/field/generator/internal/templates/element/ops_purego.go b/field/generator/internal/templates/element/ops_purego.go index d3340a8ff..0c6b4f2ce 100644 --- a/field/generator/internal/templates/element/ops_purego.go +++ b/field/generator/internal/templates/element/ops_purego.go @@ -50,7 +50,7 @@ func reduce(z *{{.ElementName}}) { _reduceGeneric(z) } -{{- if eq .NbWords 4}} +{{- if .ASMVector}} // Add adds two vectors element-wise and stores the result in self. // It panics if the vectors don't have the same length. func (vector *Vector) Add(a, b Vector) { diff --git a/field/generator/internal/templates/element/tests.go b/field/generator/internal/templates/element/tests.go index f31528489..cbca9710a 100644 --- a/field/generator/internal/templates/element/tests.go +++ b/field/generator/internal/templates/element/tests.go @@ -653,8 +653,6 @@ func Test{{toTitle .ElementName}}BitLen(t *testing.T) { )) properties.TestingRun(t, gopter.ConsoleReporter(false)) - - } @@ -740,11 +738,17 @@ func Test{{toTitle .ElementName}}VecOps(t *testing.T) { // set m to max values element // it's not really q-1 (since we have montgomery representation) // but it's the "largest" legal value - qMinus1 := new(big.Int).Sub(Modulus(), big.NewInt(1)) - - var eQMinus1 {{.ElementName}} - for i, v := range qMinus1.Bits() { - eQMinus1[i] = uint64(v) + eQMinus1 := q{{.ElementName}} + if eQMinus1[0] != 0 { + eQMinus1[0]-- + } else { + eQMinus1[0] = ^uint64(0) + for i := 1; i < len(eQMinus1); i++ { + if eQMinus1[i] != 0 { + eQMinus1[i]-- + break + } + } } for i := 0; i < N; i++ { diff --git a/field/generator/internal/templates/element/vector.go b/field/generator/internal/templates/element/vector.go index 6db71d7cd..b455a7940 100644 --- a/field/generator/internal/templates/element/vector.go +++ b/field/generator/internal/templates/element/vector.go @@ -193,7 +193,7 @@ func (vector Vector) Swap(i, j int) { {{/* For 4 elements, we have a special assembly path and copy this in ops_pure.go */}} -{{- if ne .NbWords 4}} +{{- if not .ASMVector}} // Add adds two vectors element-wise and stores the result in self. // It panics if the vectors don't have the same length. func (vector *Vector) Add(a, b Vector) { diff --git a/field/goff/cmd/root.go b/field/goff/cmd/root.go index 15ed94fc0..8217847d7 100644 --- a/field/goff/cmd/root.go +++ b/field/goff/cmd/root.go @@ -71,7 +71,8 @@ func cmdGenerate(cmd *cobra.Command, args []string) { fmt.Printf("\n%s\n", err.Error()) os.Exit(-1) } - if err := generator.GenerateFF(F, fOutputDir); err != nil { + // TODO @gbotrel this is broken with new asm dir. + if err := generator.GenerateFF(F, fOutputDir, "FIXME"); err != nil { fmt.Printf("\n%s\n", err.Error()) os.Exit(-1) } diff --git a/field/goldilocks/element.go b/field/goldilocks/element.go index 3afe5c447..8dd4d6991 100644 --- a/field/goldilocks/element.go +++ b/field/goldilocks/element.go @@ -412,32 +412,8 @@ func (z *Element) Select(c int, x0 *Element, x1 *Element) *Element { // and is used for testing purposes. 
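Several hunks in this patch, including the goldilocks _mulGeneric just below, replace the inline CIOS walkthrough with a citation to Algorithm 2 of the El Housni and Botrel paper. Purely as an illustrative cross-check, and assuming a package that imports math/big, one Montgomery multiplication can be modelled at the integer level; this is a specification-level sketch, not the word-level CIOS schedule the generated code uses:

    // montMulRef returns x*y*R^-1 mod q for x, y < q, with R = 2^(64*nbWords).
    // q must be odd; qInvNeg below is the full-width analogue of qInv0.
    func montMulRef(x, y, q *big.Int, nbWords uint) *big.Int {
        R := new(big.Int).Lsh(big.NewInt(1), 64*nbWords)
        qInv := new(big.Int).ModInverse(q, R) // q^-1 mod R
        qInvNeg := new(big.Int).Sub(R, qInv)  // -q^-1 mod R
        t := new(big.Int).Mul(x, y)
        m := new(big.Int).Mod(t, R)
        m.Mul(m, qInvNeg)
        m.Mod(m, R)          // m := (t mod R) * -q^-1 mod R
        m.Mul(m, q)
        t.Add(t, m)          // t + m*q is divisible by R by construction
        t.Rsh(t, 64*nbWords) // exact division by R
        if t.Cmp(q) >= 0 {   // at most one final subtraction since x, y < q
            t.Sub(t, q)
        }
        return t
    }

The generated routines compute the same value but interleave this reduction with the schoolbook multiplication limb by limb, which is what the ADCX/ADOX carry chains in the mul and fromMont bodies above implement.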
func _mulGeneric(z, x, y *Element) { - // Implements CIOS multiplication -- section 2.3.2 of Tolga Acar's thesis - // https://www.microsoft.com/en-us/research/wp-content/uploads/1998/06/97Acar.pdf - // - // The algorithm: - // - // for i=0 to N-1 - // C := 0 - // for j=0 to N-1 - // (C,t[j]) := t[j] + x[j]*y[i] + C - // (t[N+1],t[N]) := t[N] + C - // - // C := 0 - // m := t[0]*q'[0] mod D - // (C,_) := t[0] + m*q[0] - // for j=1 to N-1 - // (C,t[j-1]) := t[j] + m*q[j] + C - // - // (C,t[N-1]) := t[N] + C - // t[N] := t[N+1] + C - // - // → N is the number of machine words needed to store the modulus q - // → D is the word size. For example, on a 64-bit architecture D is 2 64 - // → x[i], y[i], q[i] is the ith word of the numbers x,y,q - // → q'[0] is the lowest word of the number -q⁻¹ mod r. This quantity is pre-computed, as it does not depend on the inputs. - // → t is a temporary array of size N+2 - // → C, S are machine words. A pair (C,S) refers to (hi-bits, lo-bits) of a two-word number + // Algorithm 2 of "Faster Montgomery Multiplication and Multi-Scalar-Multiplication for SNARKS" + // by Y. El Housni and G. Botrel https://doi.org/10.46586/tches.v2023.i3.504-521 var t [2]uint64 var D uint64 diff --git a/field/goldilocks/element_test.go b/field/goldilocks/element_test.go index 1746bde1c..5b7baa646 100644 --- a/field/goldilocks/element_test.go +++ b/field/goldilocks/element_test.go @@ -582,7 +582,6 @@ func TestElementBitLen(t *testing.T) { )) properties.TestingRun(t, gopter.ConsoleReporter(false)) - } func TestElementButterflies(t *testing.T) { @@ -664,11 +663,17 @@ func TestElementVecOps(t *testing.T) { // set m to max values element // it's not really q-1 (since we have montgomery representation) // but it's the "largest" legal value - qMinus1 := new(big.Int).Sub(Modulus(), big.NewInt(1)) - - var eQMinus1 Element - for i, v := range qMinus1.Bits() { - eQMinus1[i] = uint64(v) + eQMinus1 := qElement + if eQMinus1[0] != 0 { + eQMinus1[0]-- + } else { + eQMinus1[0] = ^uint64(0) + for i := 1; i < len(eQMinus1); i++ { + if eQMinus1[i] != 0 { + eQMinus1[i]-- + break + } + } } for i := 0; i < N; i++ { diff --git a/field/goldilocks/internal/main.go b/field/goldilocks/internal/main.go index 4f5bacd3f..235bb7982 100644 --- a/field/goldilocks/internal/main.go +++ b/field/goldilocks/internal/main.go @@ -14,7 +14,7 @@ func main() { if err != nil { panic(err) } - if err := generator.GenerateFF(goldilocks, "../"); err != nil { + if err := generator.GenerateFF(goldilocks, "../", ""); err != nil { panic(err) } fmt.Println("successfully generated goldilocks field") diff --git a/go.mod b/go.mod index dd93d660a..0fc43903d 100644 --- a/go.mod +++ b/go.mod @@ -10,6 +10,7 @@ require ( github.com/spf13/cobra v1.8.1 github.com/stretchr/testify v1.9.0 golang.org/x/crypto v0.26.0 + golang.org/x/sync v0.1.0 golang.org/x/sys v0.24.0 gopkg.in/yaml.v2 v2.4.0 ) diff --git a/go.sum b/go.sum index 3d91ce1bf..09f6b3b64 100644 --- a/go.sum +++ b/go.sum @@ -384,6 +384,7 @@ golang.org/x/sync v0.0.0-20201020160332-67f06af15bc9/go.mod h1:RxMgew5VJxzue5/jJ golang.org/x/sync v0.0.0-20201207232520-09787c993a3a/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20210220032951-036812b2e83c/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sync v0.1.0 h1:wsuoTGHzEhffawBOhz5CYhcrV4IdKZbEyZjBMuTp12o= golang.org/x/sync v0.1.0/go.mod 
h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sys v0.0.0-20180823144017-11551d06cbcc/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20180830151530-49385e6e1522/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= diff --git a/internal/generator/main.go b/internal/generator/main.go index 389f96c2e..aa8d71eee 100644 --- a/internal/generator/main.go +++ b/internal/generator/main.go @@ -45,26 +45,39 @@ var bgen = bavard.NewBatchGenerator(copyrightHolder, copyrightYear, "consensys/g func main() { var wg sync.WaitGroup + // generate common assembly files depending on field number of words + m := make(map[int]bool) + for i, conf := range config.Curves { + var err error + // generate base field + conf.Fp, err = field.NewFieldConfig("fp", "Element", conf.FpModulus, true) + assertNoError(err) + + conf.Fr, err = field.NewFieldConfig("fr", "Element", conf.FrModulus, !conf.Equal(config.STARK_CURVE)) + assertNoError(err) + + m[conf.Fr.NbWords] = true + m[conf.Fp.NbWords] = true + + config.Curves[i] = conf + } + asmDir := filepath.Join(baseDir, "field", "asm") + for nbWords := range m { + assertNoError(generator.GenerateCommonASM(nbWords, asmDir)) + } + for _, conf := range config.Curves { wg.Add(1) // for each curve, generate the needed files go func(conf config.Curve) { defer wg.Done() - var err error curveDir := filepath.Join(baseDir, "ecc", conf.Name) - // generate base field - conf.Fp, err = field.NewFieldConfig("fp", "Element", conf.FpModulus, true) - assertNoError(err) - - conf.Fr, err = field.NewFieldConfig("fr", "Element", conf.FrModulus, !conf.Equal(config.STARK_CURVE)) - assertNoError(err) - conf.FpUnusedBits = 64 - (conf.Fp.NbBits % 64) - assertNoError(generator.GenerateFF(conf.Fr, filepath.Join(curveDir, "fr"))) - assertNoError(generator.GenerateFF(conf.Fp, filepath.Join(curveDir, "fp"))) + assertNoError(generator.GenerateFF(conf.Fr, filepath.Join(curveDir, "fr"), filepath.Join("..", asmDir))) + assertNoError(generator.GenerateFF(conf.Fp, filepath.Join(curveDir, "fp"), filepath.Join("..", asmDir))) // generate ecdsa assertNoError(ecdsa.Generate(conf, curveDir, bgen)) diff --git a/internal/generator/tower/asm/amd64/e2.go b/internal/generator/tower/asm/amd64/e2.go index 5fb3ea969..d255d5a69 100644 --- a/internal/generator/tower/asm/amd64/e2.go +++ b/internal/generator/tower/asm/amd64/e2.go @@ -35,7 +35,7 @@ type Fq2Amd64 struct { // NewFq2Amd64 ... func NewFq2Amd64(w io.Writer, F *field.FieldConfig, config config.Curve) *Fq2Amd64 { return &Fq2Amd64{ - amd64.NewFFAmd64(w, F), + amd64.NewFFAmd64(w, F.NbWords), config, w, F, @@ -49,7 +49,7 @@ func (fq2 *Fq2Amd64) Generate(forceADXCheck bool) error { fq2.WriteLn("#include \"textflag.h\"") fq2.WriteLn("#include \"funcdata.h\"") - fq2.GenerateDefines() + fq2.GenerateDefinesDeprecated(fq2.F) if fq2.config.Equal(config.BN254) { fq2.generateMulDefine() } @@ -174,7 +174,7 @@ func (fq2 *Fq2Amd64) generateNegE2() { // z = x - q for i := 0; i < fq2.NbWords; i++ { - fq2.MOVQ(fq2.Q[i], q) + fq2.MOVQ(fq2.F.Q[i], q) if i == 0 { fq2.SUBQ(t[i], q) } else { @@ -208,7 +208,7 @@ func (fq2 *Fq2Amd64) generateNegE2() { // z = x - q for i := 0; i < fq2.NbWords; i++ { - fq2.MOVQ(fq2.Q[i], q) + fq2.MOVQ(fq2.F.Q[i], q) if i == 0 { fq2.SUBQ(t[i], q) } else { @@ -272,7 +272,7 @@ func (fq2 *Fq2Amd64) modReduceAfterSub(registers *ramd64.Registers, zero ramd64. 
} func (fq2 *Fq2Amd64) modReduceAfterSubScratch(zero ramd64.Register, t, scratch []ramd64.Register) { - fq2.Mov(fq2.Q, scratch) + fq2.Mov(fq2.F.Q, scratch) for i := 0; i < fq2.NbWords; i++ { fq2.CMOVQCC(zero, scratch[i]) } diff --git a/internal/generator/tower/asm/amd64/e2_bn254.go b/internal/generator/tower/asm/amd64/e2_bn254.go index 821338109..a25145bf7 100644 --- a/internal/generator/tower/asm/amd64/e2_bn254.go +++ b/internal/generator/tower/asm/amd64/e2_bn254.go @@ -306,7 +306,7 @@ func (fq2 *Fq2Amd64) generateMulDefine() { } wd := writerDefine{fq2.w} - tw := gamd64.NewFFAmd64(&wd, fq2.F) + tw := gamd64.NewFFAmd64(&wd, fq2.F.NbWords) _, _ = io.WriteString(fq2.w, "// this code is generated and identical to fp.Mul(...)\n") _, _ = io.WriteString(fq2.w, "#define MUL() \\ \n") From b27149d815818e9100a21546909630fd5858af21 Mon Sep 17 00:00:00 2001 From: Gautam Botrel Date: Tue, 24 Sep 2024 10:00:43 -0500 Subject: [PATCH 14/14] doc: add reference for reduction algorithm --- field/asm/element_4w_amd64.h | 1 + field/generator/asm/amd64/element_vec.go | 1 + 2 files changed, 2 insertions(+) diff --git a/field/asm/element_4w_amd64.h b/field/asm/element_4w_amd64.h index a1fa17f7b..51935dae6 100644 --- a/field/asm/element_4w_amd64.h +++ b/field/asm/element_4w_amd64.h @@ -769,6 +769,7 @@ TEXT ·sumVec(SB), NOSPLIT, $0-24 // r[3] -> R11 // r[4] -> R12 // reduce using single-word Barrett + // see see Handbook of Applied Cryptography, Algorithm 14.42. // mu=2^288 / q -> SI MOVQ mu<>(SB), SI MOVQ R11, AX diff --git a/field/generator/asm/amd64/element_vec.go b/field/generator/asm/amd64/element_vec.go index e8a380abd..b32ab8d9f 100644 --- a/field/generator/asm/amd64/element_vec.go +++ b/field/generator/asm/amd64/element_vec.go @@ -476,6 +476,7 @@ func (f *FFAmd64) generateSumVec() { mu := f.Pop(®isters) f.Comment("reduce using single-word Barrett") + f.Comment("see see Handbook of Applied Cryptography, Algorithm 14.42.") f.LabelRegisters("mu=2^288 / q", mu) f.MOVQ(f.mu(), mu) f.MOVQ(r3, amd64.AX)
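To tie the pieces of sumVec together, the following scalar Go model performs the same computation as the vector routine. It is illustrative only (the generated code uses the AVX-512 schedule shown in the earlier hunks) and assumes math/big, a 4-word modulus q of more than 225 bits, mu = floor(2^288/q) as precomputed in field_config.go when ASMVector is set, and len(vec) < 2^32 so that no 64-bit accumulator can overflow:

    // sumVecModel sums 4x64-bit little-endian elements and reduces the result mod q.
    func sumVecModel(vec [][4]uint64, q, mu *big.Int) *big.Int {
        // accumulate the eight 32-bit halves of each element into 64-bit lanes,
        // the scalar equivalent of VPMOVZXDQ + VPADDQ on the Z registers
        var acc [8]uint64
        for _, e := range vec {
            for i := 0; i < 4; i++ {
                acc[2*i] += e[i] & 0xffffffff // low 32 bits of limb i
                acc[2*i+1] += e[i] >> 32      // high 32 bits of limb i
            }
        }

        // recombine: acc[j] has weight 2^(32*j), so r equals the exact sum and r < 2^288
        r := new(big.Int)
        for j := 7; j >= 0; j-- {
            r.Lsh(r, 32)
            r.Add(r, new(big.Int).SetUint64(acc[j]))
        }

        // single-word Barrett step (Handbook of Applied Cryptography, Algorithm 14.42):
        // qhat = floor(floor(r/2^224) * mu / 2^64) never exceeds floor(r/q), so the
        // remainder stays non-negative; the loop performs the final conditional
        // subtractions (the assembly performs two conditional passes).
        qhat := new(big.Int).Rsh(r, 224)
        qhat.Mul(qhat, mu)
        qhat.Rsh(qhat, 64)
        qhat.Mul(qhat, q)
        r.Sub(r, qhat)
        for r.Cmp(q) >= 0 {
            r.Sub(r, q)
        }
        return r
    }

Splitting each limb into 32-bit halves is what lets up to 2^32-1 elements be added without any carry handling inside the hot loop, which is why sumVec clamps N as noted at its top; the half-width lanes are only recombined once, in the carry-propagation block that feeds the Barrett reduction.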