Skip to content

Commit

Permalink
[circuits/float]: further improved bootstrapper API
Browse files Browse the repository at this point in the history
  • Loading branch information
Pro7ech committed Sep 14, 2023
1 parent 91f4c80 commit f11c30c
Show file tree
Hide file tree
Showing 5 changed files with 225 additions and 203 deletions.
123 changes: 9 additions & 114 deletions circuits/float/bootstrapper/bootstrapper.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,122 +27,14 @@ type Bootstrapper struct {
xPow2N2 []ring.Poly
xPow2InvN2 []ring.Poly

evk BootstrappingKeys
evk *BootstrappingKeys
}

type BootstrappingKeys struct {
EvkN1ToN2 *rlwe.EvaluationKey
EvkN2ToN1 *rlwe.EvaluationKey
EvkRealToCmplx *rlwe.EvaluationKey
EvkCmplxToReal *rlwe.EvaluationKey
EvkBootstrapping *bootstrapping.EvaluationKeySet
}

func (b BootstrappingKeys) BinarySize() (dLen int) {
if b.EvkN1ToN2 != nil {
dLen += b.EvkN1ToN2.BinarySize()
}

if b.EvkN2ToN1 != nil {
dLen += b.EvkN2ToN1.BinarySize()
}

if b.EvkRealToCmplx != nil {
dLen += b.EvkRealToCmplx.BinarySize()
}

if b.EvkCmplxToReal != nil {
dLen += b.EvkCmplxToReal.BinarySize()
}

if b.EvkBootstrapping != nil {
dLen += b.EvkBootstrapping.BinarySize()
}

return
}

func GenBootstrappingKeys(paramsN1 ckks.Parameters, btpParamsN2 bootstrapping.Parameters, skN1 *rlwe.SecretKey) (BootstrappingKeys, error) {

var EvkN1ToN2, EvkN2ToN1 *rlwe.EvaluationKey
var EvkRealToCmplx *rlwe.EvaluationKey
var EvkCmplxToReal *rlwe.EvaluationKey
paramsN2 := btpParamsN2.Parameters

// Checks that the maximum level of paramsN1 is equal to the remaining level after the bootstrapping of paramsN2
if paramsN2.MaxLevel()-btpParamsN2.SlotsToCoeffsParameters.Depth(true)-btpParamsN2.Mod1ParametersLiteral.Depth()-btpParamsN2.CoeffsToSlotsParameters.Depth(true) < paramsN1.MaxLevel() {
return BootstrappingKeys{}, fmt.Errorf("cannot GenBootstrappingKeys: bootstrapping depth is too large, level after bootstrapping is smaller than paramsN1.MaxLevel()")
}

// Checks that the overlapping primes between paramsN1 and paramsN2 are the same, i.e.
// pN1: q0, q1, q2, ..., qL
// pN2: q0, q1, q2, ..., qL, [bootstrapping primes]
QN1 := paramsN1.Q()
QN2 := paramsN2.Q()

for i := range QN1 {
if QN1[i] != QN2[i] {
return BootstrappingKeys{}, fmt.Errorf("cannot GenBootstrappingKeys: paramsN1.Q() is not a subset of paramsN2.Q()")
}
}

kgen := ckks.NewKeyGenerator(paramsN2)

// Ephemeral secret-key used to generate the evaluation keys.
skN2 := rlwe.NewSecretKey(paramsN2)
buff := paramsN2.RingQ().NewPoly()
ringQ := paramsN2.RingQ()
ringP := paramsN2.RingP()

switch paramsN1.RingType() {
// In this case we need need generate the bridge switching keys between the two rings
case ring.ConjugateInvariant:

if paramsN1.LogN() != paramsN2.LogN()-1 {
return BootstrappingKeys{}, fmt.Errorf("cannot GenBootstrappingKeys: if paramsN1.RingType() == ring.ConjugateInvariant then must ensure that paramsN1.LogN()+1 == paramsN2.LogN()-1")
}

// R[X+X^-1]/(X^N +1) -> R[X]/(X^2N + 1)
ringQ.AtLevel(skN1.LevelQ()).UnfoldConjugateInvariantToStandard(skN1.Value.Q, skN2.Value.Q)

// Extends basis Q0 -> QL
rlwe.ExtendBasisSmallNormAndCenterNTTMontgomery(ringQ, ringQ, skN2.Value.Q, buff, skN2.Value.Q)

// Extends basis Q0 -> P
rlwe.ExtendBasisSmallNormAndCenterNTTMontgomery(ringQ, ringP, skN2.Value.Q, buff, skN2.Value.P)

EvkCmplxToReal, EvkRealToCmplx = kgen.GenEvaluationKeysForRingSwapNew(skN2, skN1)

// Only regular key-switching is required in this case
case ring.Standard:

// Maps the smaller key to the largest with Y = X^{N/n}.
ring.MapSmallDimensionToLargerDimensionNTT(skN1.Value.Q, skN2.Value.Q)

// Extends basis Q0 -> QL
rlwe.ExtendBasisSmallNormAndCenterNTTMontgomery(ringQ, ringQ, skN2.Value.Q, buff, skN2.Value.Q)

// Extends basis Q0 -> P
rlwe.ExtendBasisSmallNormAndCenterNTTMontgomery(ringQ, ringP, skN2.Value.Q, buff, skN2.Value.P)

EvkN1ToN2 = kgen.GenEvaluationKeyNew(skN1, skN2)
EvkN2ToN1 = kgen.GenEvaluationKeyNew(skN2, skN1)
}

return BootstrappingKeys{
EvkN1ToN2: EvkN1ToN2,
EvkN2ToN1: EvkN2ToN1,
EvkRealToCmplx: EvkRealToCmplx,
EvkCmplxToReal: EvkCmplxToReal,
EvkBootstrapping: btpParamsN2.GenEvaluationKeySetNew(skN2),
}, nil
}

func NewBootstrapper(paramsN1 ckks.Parameters, btpParamsN2 bootstrapping.Parameters, evk BootstrappingKeys) (rlwe.Bootstrapper, error) {
func NewBootstrapper(paramsN1 ckks.Parameters, btpParamsN2 Parameters, evk *BootstrappingKeys) (*Bootstrapper, error) {

b := &Bootstrapper{}

paramsN2 := btpParamsN2.Parameters
paramsN2 := btpParamsN2.Parameters.Parameters

switch paramsN1.RingType() {
case ring.Standard:
Expand All @@ -165,7 +57,7 @@ func NewBootstrapper(paramsN1 ckks.Parameters, btpParamsN2 bootstrapping.Paramet

b.paramsN1 = paramsN1
b.paramsN2 = paramsN2
b.btpParamsN2 = btpParamsN2
b.btpParamsN2 = btpParamsN2.Parameters
b.evk = evk

b.xPow2N2 = rlwe.GenXPow2(b.paramsN2.RingQ().AtLevel(0), b.paramsN2.LogN(), false)
Expand All @@ -180,7 +72,7 @@ func NewBootstrapper(paramsN1 ckks.Parameters, btpParamsN2 bootstrapping.Paramet
}

var err error
if b.bootstrapper, err = bootstrapping.NewBootstrapper(btpParamsN2, evk.EvkBootstrapping); err != nil {
if b.bootstrapper, err = bootstrapping.NewBootstrapper(btpParamsN2.Parameters, evk.EvkBootstrapping); err != nil {
return nil, err
}

Expand All @@ -202,7 +94,10 @@ func (b Bootstrapper) MinimumInputLevel() int {
func (b Bootstrapper) Bootstrap(ct *rlwe.Ciphertext) (*rlwe.Ciphertext, error) {
cts := []*rlwe.Ciphertext{ct}
cts, err := b.BootstrapMany(cts)
return cts[0], err
if err != nil {
return nil, err
}
return cts[0], nil
}

func (b Bootstrapper) BootstrapMany(cts []*rlwe.Ciphertext) ([]*rlwe.Ciphertext, error) {
Expand Down
102 changes: 50 additions & 52 deletions circuits/float/bootstrapper/bootstrapping_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@ import (
"testing"

"github.com/stretchr/testify/require"
"github.com/tuneinsight/lattigo/v4/circuits/float/bootstrapper/bootstrapping"
"github.com/tuneinsight/lattigo/v4/ckks"
"github.com/tuneinsight/lattigo/v4/ring"
"github.com/tuneinsight/lattigo/v4/rlwe"
Expand All @@ -17,26 +16,33 @@ import (
var flagLongTest = flag.Bool("long", false, "run the long test suite (all parameters + secure bootstrapping). Overrides -short and requires -timeout=0.")
var printPrecisionStats = flag.Bool("print-precision", false, "print precision stats")

var testPrec45 = ckks.ParametersLiteral{
LogN: 10,
LogQ: []int{60, 40},
LogP: []int{61},
LogDefaultScale: 40,
}

func TestBootstrapping(t *testing.T) {

// Check that the bootstrapper complies to the rlwe.Bootstrapper interface
var _ rlwe.Bootstrapper = (*Bootstrapper)(nil)

t.Run("BootstrapingWithoutRingDegreeSwitch", func(t *testing.T) {

paramSet := bootstrapping.DefaultParametersSparse[0]
paramSet.SchemeParams.LogQ = paramSet.SchemeParams.LogQ[:utils.Min(2, len(paramSet.SchemeParams.LogQ))]
schemeParamsLit := testPrec45
btpParamsLit := ParametersLiteral{}

if !*flagLongTest {
paramSet.SchemeParams.LogN = 13
if *flagLongTest {
schemeParamsLit.LogN = 16
}

params, err := ckks.NewParametersFromLiteral(paramSet.SchemeParams)
params, err := ckks.NewParametersFromLiteral(schemeParamsLit)
require.Nil(t, err)

paramSet.BootstrappingParams.LogN = utils.Pointy(params.LogN())
btpParamsLit.LogN = utils.Pointy(params.LogN())

btpParams, err := bootstrapping.NewParametersFromLiteral(params, paramSet.BootstrappingParams)
btpParams, err := NewParametersFromLiteral(params, btpParamsLit)
require.Nil(t, err)

// Insecure params for fast testing only
Expand All @@ -50,17 +56,15 @@ func TestBootstrapping(t *testing.T) {

t.Logf("ParamsN2: LogN=%d/LogSlots=%d/LogQP=%f", params.LogN(), params.LogMaxSlots(), params.LogQP())

sk := ckks.NewKeyGenerator(btpParams.Parameters).GenSecretKeyNew()
sk := ckks.NewKeyGenerator(btpParams.Parameters.Parameters).GenSecretKeyNew()

t.Log("Generating Bootstrapping Keys")
btpKeys, err := GenBootstrappingKeys(params, btpParams, sk)
btpKeys, err := btpParams.GenBootstrappingKeys(params, sk)
require.NoError(t, err)

bootstrapperInterface, err := NewBootstrapper(params, btpParams, btpKeys)
bootstrapper, err := NewBootstrapper(params, btpParams, btpKeys)
require.NoError(t, err)

bootstrapper := bootstrapperInterface.(*Bootstrapper)

ecd := ckks.NewEncoder(params)
enc := ckks.NewEncryptor(params, sk)
dec := ckks.NewDecryptor(params, sk)
Expand Down Expand Up @@ -102,22 +106,22 @@ func TestBootstrapping(t *testing.T) {

t.Run("BootstrappingWithRingDegreeSwitch", func(t *testing.T) {

paramSet := bootstrapping.DefaultParametersSparse[0]
paramSet.SchemeParams.LogQ = paramSet.SchemeParams.LogQ[:utils.Min(2, len(paramSet.SchemeParams.LogQ))]
schemeParamsLit := testPrec45
btpParamsLit := ParametersLiteral{}

if !*flagLongTest {
paramSet.SchemeParams.LogN = 13
paramSet.SchemeParams.LogNthRoot = paramSet.SchemeParams.LogN + 1
if *flagLongTest {
schemeParamsLit.LogN = 16
}

paramSet.SchemeParams.LogN--
schemeParamsLit.LogNthRoot = schemeParamsLit.LogN + 1
schemeParamsLit.LogN--

params, err := ckks.NewParametersFromLiteral(paramSet.SchemeParams)
params, err := ckks.NewParametersFromLiteral(schemeParamsLit)
require.Nil(t, err)

paramSet.BootstrappingParams.LogN = utils.Pointy(params.LogN() + 1)
btpParamsLit.LogN = utils.Pointy(params.LogN() + 1)

btpParams, err := bootstrapping.NewParametersFromLiteral(params, paramSet.BootstrappingParams)
btpParams, err := NewParametersFromLiteral(params, btpParamsLit)
require.Nil(t, err)

// Insecure params for fast testing only
Expand All @@ -135,7 +139,7 @@ func TestBootstrapping(t *testing.T) {
sk := ckks.NewKeyGenerator(params).GenSecretKeyNew()

t.Log("Generating Bootstrapping Keys")
btpKeys, err := GenBootstrappingKeys(params, btpParams, sk)
btpKeys, err := btpParams.GenBootstrappingKeys(params, sk)
require.Nil(t, err)

bootstrapper, err := NewBootstrapper(params, btpParams, btpKeys)
Expand Down Expand Up @@ -186,22 +190,21 @@ func TestBootstrapping(t *testing.T) {

t.Run("BootstrappingPackedWithRingDegreeSwitch", func(t *testing.T) {

paramSet := bootstrapping.DefaultParametersSparse[0]
paramSet.SchemeParams.LogQ = paramSet.SchemeParams.LogQ[:utils.Min(2, len(paramSet.SchemeParams.LogQ))]
schemeParamsLit := testPrec45
btpParamsLit := ParametersLiteral{}

if !*flagLongTest {
paramSet.SchemeParams.LogN = 13
paramSet.SchemeParams.LogNthRoot = paramSet.SchemeParams.LogN + 1
if *flagLongTest {
schemeParamsLit.LogN = 16
}

paramSet.SchemeParams.LogN -= 5
btpParamsLit.LogN = utils.Pointy(schemeParamsLit.LogN)
schemeParamsLit.LogNthRoot = schemeParamsLit.LogN + 1
schemeParamsLit.LogN -= 3

params, err := ckks.NewParametersFromLiteral(paramSet.SchemeParams)
params, err := ckks.NewParametersFromLiteral(schemeParamsLit)
require.Nil(t, err)

paramSet.BootstrappingParams.LogN = utils.Pointy(params.LogN() + 5)

btpParams, err := bootstrapping.NewParametersFromLiteral(params, paramSet.BootstrappingParams)
btpParams, err := NewParametersFromLiteral(params, btpParamsLit)
require.Nil(t, err)

// Insecure params for fast testing only
Expand All @@ -219,7 +222,7 @@ func TestBootstrapping(t *testing.T) {
sk := ckks.NewKeyGenerator(params).GenSecretKeyNew()

t.Log("Generating Bootstrapping Keys")
btpKeys, err := GenBootstrappingKeys(params, btpParams, sk)
btpKeys, err := btpParams.GenBootstrappingKeys(params, sk)
require.Nil(t, err)

bootstrapper, err := NewBootstrapper(params, btpParams, btpKeys)
Expand All @@ -243,7 +246,7 @@ func TestBootstrapping(t *testing.T) {

pt := ckks.NewPlaintext(params, 0)

cts := make([]*rlwe.Ciphertext, 17)
cts := make([]*rlwe.Ciphertext, 7)
for i := range cts {

require.NoError(t, ecd.Encode(utils.RotateSlice(values, i), pt))
Expand All @@ -269,29 +272,26 @@ func TestBootstrapping(t *testing.T) {

t.Run("BootstrappingWithRingTypeSwitch", func(t *testing.T) {

paramSet := bootstrapping.DefaultParametersSparse[0]
paramSet.SchemeParams.LogQ = paramSet.SchemeParams.LogQ[:utils.Min(2, len(paramSet.SchemeParams.LogQ))]
paramSet.SchemeParams.RingType = ring.ConjugateInvariant
schemeParamsLit := testPrec45
schemeParamsLit.RingType = ring.ConjugateInvariant
btpParamsLit := ParametersLiteral{}

if !*flagLongTest {
paramSet.SchemeParams.LogN = 13
if *flagLongTest {
schemeParamsLit.LogN = 16
}

paramSet.SchemeParams.LogN--
btpParamsLit.LogN = utils.Pointy(schemeParamsLit.LogN)
schemeParamsLit.LogNthRoot = schemeParamsLit.LogN + 1
schemeParamsLit.LogN--

params, err := ckks.NewParametersFromLiteral(paramSet.SchemeParams)
params, err := ckks.NewParametersFromLiteral(schemeParamsLit)
require.Nil(t, err)

paramSet.BootstrappingParams.LogN = utils.Pointy(params.LogN() + 1)

btpParams, err := bootstrapping.NewParametersFromLiteral(params, paramSet.BootstrappingParams)
btpParams, err := NewParametersFromLiteral(params, btpParamsLit)
require.Nil(t, err)

// Insecure params for fast testing only
if !*flagLongTest {
btpParams.SlotsToCoeffsParameters.LogSlots = btpParams.LogN() - 1
btpParams.CoeffsToSlotsParameters.LogSlots = btpParams.LogN() - 1

// Corrects the message ratio to take into account the smaller number of slots and keep the same precision
btpParams.Mod1ParametersLiteral.LogMessageRatio += 16 - params.LogN()
}
Expand All @@ -302,14 +302,12 @@ func TestBootstrapping(t *testing.T) {
sk := ckks.NewKeyGenerator(params).GenSecretKeyNew()

t.Log("Generating Bootstrapping Keys")
btpKeys, err := GenBootstrappingKeys(params, btpParams, sk)
btpKeys, err := btpParams.GenBootstrappingKeys(params, sk)
require.Nil(t, err)

bootstrapperInterface, err := NewBootstrapper(params, btpParams, btpKeys)
bootstrapper, err := NewBootstrapper(params, btpParams, btpKeys)
require.Nil(t, err)

bootstrapper := bootstrapperInterface.(*Bootstrapper)

ecd := ckks.NewEncoder(params)
enc := ckks.NewEncryptor(params, sk)
dec := ckks.NewDecryptor(params, sk)
Expand Down
Loading

0 comments on commit f11c30c

Please sign in to comment.