Skip to content

Commit

Permalink
Merge pull request ethereum#88 from zama-ai/louis-test-types-precompiles
Browse files Browse the repository at this point in the history
feat(testing): test types in precompiles
  • Loading branch information
tremblaythibaultl authored May 9, 2023
2 parents 69efcf2 + dca7611 commit bc38b5f
Show file tree
Hide file tree
Showing 2 changed files with 304 additions and 247 deletions.
168 changes: 146 additions & 22 deletions core/vm/contracts_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -453,16 +453,29 @@ func toPrecompileInput(hashes ...common.Hash) []byte {
return ret
}

func TestFheAdd(t *testing.T) {
func FheAdd(t *testing.T, fheUintType fheUintType) {
var lhs, rhs uint64
switch fheUintType {
case FheUint8:
lhs = 2
rhs = 1
case FheUint16:
lhs = 4283
rhs = 1337
case FheUint32:
lhs = 1333337
rhs = 133337
}
expected := lhs + rhs
c := &fheAdd{}
depth := 1
state := newTestState()
state.interpreter.evm.depth = depth
state.interpreter.evm.Commit = true
addr := common.Address{}
readOnly := false
lhsHash := verifyCiphertextInTestMemory(state.interpreter, 1, depth, FheUint8).getHash()
rhsHash := verifyCiphertextInTestMemory(state.interpreter, 1, depth, FheUint8).getHash()
lhsHash := verifyCiphertextInTestMemory(state.interpreter, lhs, depth, fheUintType).getHash()
rhsHash := verifyCiphertextInTestMemory(state.interpreter, rhs, depth, fheUintType).getHash()
input := toPrecompileInput(lhsHash, rhsHash)
out, err := c.Run(state, addr, addr, input, readOnly)
if err != nil {
Expand All @@ -473,21 +486,34 @@ func TestFheAdd(t *testing.T) {
t.Fatalf("output ciphertext is not found in verifiedCiphertexts")
}
decrypted := res.ciphertext.decrypt()
if decrypted != 2 {
if decrypted != uint64(expected) {
t.Fatalf("invalid decrypted result")
}
}

func TestFheSub(t *testing.T) {
func FheSub(t *testing.T, fheUintType fheUintType) {
var lhs, rhs uint64
switch fheUintType {
case FheUint8:
lhs = 2
rhs = 1
case FheUint16:
lhs = 4283
rhs = 1337
case FheUint32:
lhs = 1333337
rhs = 133337
}
expected := lhs - rhs
c := &fheSub{}
depth := 1
state := newTestState()
state.interpreter.evm.depth = depth
state.interpreter.evm.Commit = true
addr := common.Address{}
readOnly := false
lhsHash := verifyCiphertextInTestMemory(state.interpreter, 2, depth, FheUint8).getHash()
rhsHash := verifyCiphertextInTestMemory(state.interpreter, 1, depth, FheUint8).getHash()
lhsHash := verifyCiphertextInTestMemory(state.interpreter, lhs, depth, fheUintType).getHash()
rhsHash := verifyCiphertextInTestMemory(state.interpreter, rhs, depth, fheUintType).getHash()
input := toPrecompileInput(lhsHash, rhsHash)
out, err := c.Run(state, addr, addr, input, readOnly)
if err != nil {
Expand All @@ -498,21 +524,34 @@ func TestFheSub(t *testing.T) {
t.Fatalf("output ciphertext is not found in verifiedCiphertexts")
}
decrypted := res.ciphertext.decrypt()
if decrypted != 1 {
if decrypted != expected {
t.Fatalf("invalid decrypted result")
}
}

func TestFheMul(t *testing.T) {
func FheMul(t *testing.T, fheUintType fheUintType) {
var lhs, rhs uint64
switch fheUintType {
case FheUint8:
lhs = 2
rhs = 1
case FheUint16:
lhs = 169
rhs = 5
case FheUint32:
lhs = 137
rhs = 17
}
expected := lhs * rhs
c := &fheMul{}
depth := 1
state := newTestState()
state.interpreter.evm.depth = depth
state.interpreter.evm.Commit = true
addr := common.Address{}
readOnly := false
lhsHash := verifyCiphertextInTestMemory(state.interpreter, 2, depth, FheUint8).getHash()
rhsHash := verifyCiphertextInTestMemory(state.interpreter, 1, depth, FheUint8).getHash()
lhsHash := verifyCiphertextInTestMemory(state.interpreter, lhs, depth, fheUintType).getHash()
rhsHash := verifyCiphertextInTestMemory(state.interpreter, rhs, depth, fheUintType).getHash()
input := toPrecompileInput(lhsHash, rhsHash)
out, err := c.Run(state, addr, addr, input, readOnly)
if err != nil {
Expand All @@ -523,23 +562,35 @@ func TestFheMul(t *testing.T) {
t.Fatalf("output ciphertext is not found in verifiedCiphertexts")
}
decrypted := res.ciphertext.decrypt()
if decrypted != 2 {
if decrypted != expected {
t.Fatalf("invalid decrypted result")
}
}

func TestFheLte(t *testing.T) {
func FheLte(t *testing.T, fheUintType fheUintType) {
var lhs, rhs uint64
switch fheUintType {
case FheUint8:
lhs = 2
rhs = 1
case FheUint16:
lhs = 4283
rhs = 1337
case FheUint32:
lhs = 1333337
rhs = 133337
}
c := &fheLte{}
depth := 1
state := newTestState()
state.interpreter.evm.depth = depth
state.interpreter.evm.Commit = true
addr := common.Address{}
readOnly := false
lhsHash := verifyCiphertextInTestMemory(state.interpreter, 2, depth, FheUint8).getHash()
rhsHash := verifyCiphertextInTestMemory(state.interpreter, 1, depth, FheUint8).getHash()
lhsHash := verifyCiphertextInTestMemory(state.interpreter, lhs, depth, fheUintType).getHash()
rhsHash := verifyCiphertextInTestMemory(state.interpreter, rhs, depth, fheUintType).getHash()

// 2 <= 1
// lhs <= rhs
input1 := toPrecompileInput(lhsHash, rhsHash)
out, err := c.Run(state, addr, addr, input1, readOnly)
if err != nil {
Expand All @@ -554,7 +605,7 @@ func TestFheLte(t *testing.T) {
t.Fatalf("invalid decrypted result")
}

// 1 <= 2
// rhs <= lhs
input2 := toPrecompileInput(rhsHash, lhsHash)
out, err = c.Run(state, addr, addr, input2, readOnly)
if err != nil {
Expand All @@ -570,18 +621,31 @@ func TestFheLte(t *testing.T) {
}
}

func TestFheLt(t *testing.T) {
func FheLt(t *testing.T, fheUintType fheUintType) {
var lhs, rhs uint64
switch fheUintType {
case FheUint8:
lhs = 2
rhs = 1
case FheUint16:
lhs = 4283
rhs = 1337
case FheUint32:
lhs = 1333337
rhs = 133337
}

c := &fheLt{}
depth := 1
state := newTestState()
state.interpreter.evm.depth = depth
state.interpreter.evm.Commit = true
addr := common.Address{}
readOnly := false
lhsHash := verifyCiphertextInTestMemory(state.interpreter, 2, depth, FheUint8).getHash()
rhsHash := verifyCiphertextInTestMemory(state.interpreter, 1, depth, FheUint8).getHash()
lhsHash := verifyCiphertextInTestMemory(state.interpreter, lhs, depth, fheUintType).getHash()
rhsHash := verifyCiphertextInTestMemory(state.interpreter, rhs, depth, fheUintType).getHash()

// 2 < 1
// lhs < rhs
input1 := toPrecompileInput(lhsHash, rhsHash)
out, err := c.Run(state, addr, addr, input1, readOnly)
if err != nil {
Expand All @@ -596,7 +660,7 @@ func TestFheLt(t *testing.T) {
t.Fatalf("invalid decrypted result")
}

// 1 < 2
// rhs < lhs
input2 := toPrecompileInput(rhsHash, lhsHash)
out, err = c.Run(state, addr, addr, input2, readOnly)
if err != nil {
Expand All @@ -612,6 +676,66 @@ func TestFheLt(t *testing.T) {
}
}

func TestFheAdd8(t *testing.T) {
FheAdd(t, FheUint8)
}

func TestFheSub8(t *testing.T) {
FheSub(t, FheUint8)
}

func TestFheMul8(t *testing.T) {
FheMul(t, FheUint8)
}

func TestFheLte8(t *testing.T) {
FheLte(t, FheUint8)
}

func TestFheLt8(t *testing.T) {
FheLt(t, FheUint8)
}

func TestFheAdd16(t *testing.T) {
FheAdd(t, FheUint16)
}

func TestFheSub16(t *testing.T) {
FheSub(t, FheUint16)
}

func TestFheMul16(t *testing.T) {
FheMul(t, FheUint16)
}

func TestFheLte16(t *testing.T) {
FheLte(t, FheUint16)
}

func TestFheLt16(t *testing.T) {
FheLt(t, FheUint16)
}

func TestFheAdd32(t *testing.T) {
FheAdd(t, FheUint32)
}

func TestFheSub32(t *testing.T) {
FheSub(t, FheUint32)
}

func TestFheMul32(t *testing.T) {
FheMul(t, FheUint32)
}

func TestFheLte32(t *testing.T) {
FheLte(t, FheUint32)
}

func TestFheLt32(t *testing.T) {
FheLt(t, FheUint32)
}

// func TestFheRand(t *testing.T) {
// c := &fheRand{}
// depth := 1
Expand Down
Loading

0 comments on commit bc38b5f

Please sign in to comment.