Skip to content

Commit

Permalink
make more client fns be stateless
Browse files Browse the repository at this point in the history
  • Loading branch information
sanjit-bhat committed Nov 15, 2024
1 parent 6c229f8 commit 1befdd6
Showing 1 changed file with 16 additions and 16 deletions.
32 changes: 16 additions & 16 deletions kt/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -59,44 +59,44 @@ func (c *Client) checkDig(dig *SigDig) *clientErr {

// checkVrfProof errors on fail.
// TODO: if VRF pubkey is bad, does VRF.Verify still mean something?
func (c *Client) checkVrf(uid uint64, ver uint64, label []byte, proof []byte) bool {
func checkVrf(pk *cryptoffi.VrfPublicKey, uid uint64, ver uint64, label []byte, proof []byte) bool {
pre := &MapLabelPre{Uid: uid, Ver: ver}
preByt := MapLabelPreEncode(make([]byte, 0), pre)
return !c.servVrfPk.Verify(preByt, label, proof)
return !pk.Verify(preByt, label, proof)
}

// checkMemb errors on fail.
func (c *Client) checkMemb(uid uint64, ver uint64, dig []byte, memb *Memb) bool {
if c.checkVrf(uid, ver, memb.Label, memb.VrfProof) {
func checkMemb(pk *cryptoffi.VrfPublicKey, uid uint64, ver uint64, dig []byte, memb *Memb) bool {
if checkVrf(pk, uid, ver, memb.Label, memb.VrfProof) {
return true
}
mapVal := compMapVal(memb.EpochAdded, memb.CommOpen)
return merkle.CheckProof(true, memb.MerkProof, memb.Label, mapVal, dig)
}

// checkMembHide errors on fail.
func (c *Client) checkMembHide(uid uint64, ver uint64, dig []byte, memb *MembHide) bool {
if c.checkVrf(uid, ver, memb.Label, memb.VrfProof) {
func checkMembHide(pk *cryptoffi.VrfPublicKey, uid uint64, ver uint64, dig []byte, memb *MembHide) bool {
if checkVrf(pk, uid, ver, memb.Label, memb.VrfProof) {
return true
}
return merkle.CheckProof(true, memb.MerkProof, memb.Label, memb.MapVal, dig)
}

// checkHist errors on fail.
func (c *Client) checkHist(uid uint64, dig []byte, membs []*MembHide) bool {
func checkHist(pk *cryptoffi.VrfPublicKey, uid uint64, dig []byte, membs []*MembHide) bool {
var err0 bool
for ver0, memb := range membs {
ver := uint64(ver0)
if c.checkMembHide(uid, ver, dig, memb) {
if checkMembHide(pk, uid, ver, dig, memb) {
err0 = true
}
}
return err0
}

// checkNonMemb errors on fail.
func (c *Client) checkNonMemb(uid uint64, ver uint64, dig []byte, nonMemb *NonMemb) bool {
if c.checkVrf(uid, ver, nonMemb.Label, nonMemb.VrfProof) {
func checkNonMemb(pk *cryptoffi.VrfPublicKey, uid uint64, ver uint64, dig []byte, nonMemb *NonMemb) bool {
if checkVrf(pk, uid, ver, nonMemb.Label, nonMemb.VrfProof) {
return true
}
return merkle.CheckProof(false, nonMemb.MerkProof, nonMemb.Label, nil, dig)
Expand All @@ -114,7 +114,7 @@ func (c *Client) Put(pk []byte) (uint64, *clientErr) {
return 0, err1
}
// check latest entry has right ver, epoch, pk.
if c.checkMemb(c.uid, c.nextVer, dig.Dig, latest) {
if checkMemb(c.servVrfPk, c.uid, c.nextVer, dig.Dig, latest) {
return 0, stdErr
}
if dig.Epoch != latest.EpochAdded {
Expand All @@ -124,7 +124,7 @@ func (c *Client) Put(pk []byte) (uint64, *clientErr) {
return 0, stdErr
}
// check bound has right ver.
if c.checkNonMemb(c.uid, c.nextVer+1, dig.Dig, bound) {
if checkNonMemb(c.servVrfPk, c.uid, c.nextVer+1, dig.Dig, bound) {
return 0, stdErr
}
c.nextVer += 1
Expand All @@ -146,7 +146,7 @@ func (c *Client) Get(uid uint64) (bool, []byte, uint64, *clientErr) {
if err1.err {
return false, nil, 0, err1
}
if c.checkHist(uid, dig.Dig, hist) {
if checkHist(c.servVrfPk, uid, dig.Dig, hist) {
return false, nil, 0, stdErr
}
numHistVers := uint64(len(hist))
Expand All @@ -155,7 +155,7 @@ func (c *Client) Get(uid uint64) (bool, []byte, uint64, *clientErr) {
return false, nil, 0, stdErr
}
// check latest has right ver.
if isReg && c.checkMemb(uid, numHistVers, dig.Dig, latest) {
if isReg && checkMemb(c.servVrfPk, uid, numHistVers, dig.Dig, latest) {
return false, nil, 0, stdErr
}
// check bound has right ver.
Expand All @@ -164,7 +164,7 @@ func (c *Client) Get(uid uint64) (bool, []byte, uint64, *clientErr) {
if isReg {
boundVer = numHistVers + 1
}
if c.checkNonMemb(uid, boundVer, dig.Dig, bound) {
if checkNonMemb(c.servVrfPk, uid, boundVer, dig.Dig, bound) {
return false, nil, 0, stdErr
}
return isReg, latest.CommOpen.Pk, dig.Epoch, &clientErr{err: false}
Expand All @@ -182,7 +182,7 @@ func (c *Client) SelfMon() (uint64, *clientErr) {
if err1.err {
return 0, err1
}
if c.checkNonMemb(c.uid, c.nextVer, dig.Dig, bound) {
if checkNonMemb(c.servVrfPk, c.uid, c.nextVer, dig.Dig, bound) {
return 0, stdErr
}
return dig.Epoch, &clientErr{err: false}
Expand Down

0 comments on commit 1befdd6

Please sign in to comment.