From 529dcd537d92abdadfbc019af060cb8c19f0bd75 Mon Sep 17 00:00:00 2001 From: lehugueni Date: Tue, 19 Nov 2024 15:30:17 +0100 Subject: [PATCH] fix innersum bgv --- schemes/bgv/evaluator.go | 37 +++++++++++++++++++++++++++++++++++++ schemes/bgv/params.go | 2 +- 2 files changed, 38 insertions(+), 1 deletion(-) diff --git a/schemes/bgv/evaluator.go b/schemes/bgv/evaluator.go index 554f7c5f3..fb94c00c0 100644 --- a/schemes/bgv/evaluator.go +++ b/schemes/bgv/evaluator.go @@ -1505,6 +1505,43 @@ func (eval Evaluator) RotateHoistedLazyNew(level int, rotations []int, op0 *rlwe return } +// InnerSum computes the inner sum of the underlying slots (see [rlwe.Evaluator.InnerSum]). +// NB: in the slot encoding of BGV/BFV, the underlying N slots are arranged as 2 rows of N/2 slots. +// If n*batchSize < N/2, InnerSum computes the [rlwe.Evaluator.InnerSum] of each row separately. +// If n*batchSize = N, InnerSum computes the [rlwe.Evaluator.InnerSum] on the concatenation of both rows. +// NOTE: In this case, InnerSum performs an addition and a [Evaluator.RotateRowsNew] on top +// Otherwise, InnerSum returns an error. +func (eval Evaluator) InnerSum(ctIn *rlwe.Ciphertext, batchSize, n int, opOut *rlwe.Ciphertext) (err error) { + N := eval.parameters.N() + halfN := N >> 1 + l := n * batchSize + + if l > halfN { + if l != N { + return fmt.Errorf("innersum: n*batchSize=%d > N/2=%d and n*batchSize != N=%d", l, halfN, N) + } + + if err = eval.Evaluator.InnerSum(ctIn, batchSize, n/2, opOut); err != nil { + return + } + + var ctRot *rlwe.Ciphertext + ctRot, err = eval.RotateRowsNew(opOut) + if err != nil { + return + } + + if err = eval.Add(opOut, ctRot, opOut); err != nil { + return + } + + return + } + + err = eval.Evaluator.InnerSum(ctIn, batchSize, n, opOut) + return +} + // MatchScalesAndLevel updates the both input ciphertexts to ensures that their scale matches. // To do so it computes t0 * a = opOut * b such that: // - ct0.Scale * a = opOut.Scale: make the scales match. diff --git a/schemes/bgv/params.go b/schemes/bgv/params.go index 066263837..8ae690de7 100644 --- a/schemes/bgv/params.go +++ b/schemes/bgv/params.go @@ -257,7 +257,7 @@ func (p Parameters) GaloisElementForRowRotation() uint64 { // InnerSum operation with parameters batch and n. func (p Parameters) GaloisElementsForInnerSum(batch, n int) (galEls []uint64) { galEls = rlwe.GaloisElementsForInnerSum(p, batch, n) - if n > p.N()>>1 { + if n*batch > p.N()>>1 { galEls = append(galEls, p.GaloisElementForRowRotation()) } return