Skip to content

Commit

Permalink
Merge pull request ethereum#38 from zama-ai/petar/do-not-import-rando…
Browse files Browse the repository at this point in the history
…m-ct-on-eth-call

Do not import random ciphertext on EthCall
  • Loading branch information
dartdart26 authored Feb 2, 2023
2 parents d42bc96 + b49c344 commit 3af35d4
Show file tree
Hide file tree
Showing 4 changed files with 136 additions and 10 deletions.
69 changes: 59 additions & 10 deletions core/vm/contracts.go
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,7 @@ var PrecompiledContractsHomestead = map[common.Address]PrecompiledContract{
common.BytesToAddress([]byte{70}): &fheLte{},
common.BytesToAddress([]byte{71}): &fheSub{},
common.BytesToAddress([]byte{72}): &fheMul{},
common.BytesToAddress([]byte{73}): &fheLt{},
}

// PrecompiledContractsByzantium contains the default set of pre-compiled Ethereum
Expand All @@ -95,6 +96,7 @@ var PrecompiledContractsByzantium = map[common.Address]PrecompiledContract{
common.BytesToAddress([]byte{70}): &fheLte{},
common.BytesToAddress([]byte{71}): &fheSub{},
common.BytesToAddress([]byte{72}): &fheMul{},
common.BytesToAddress([]byte{73}): &fheLt{},
}

// PrecompiledContractsIstanbul contains the default set of pre-compiled Ethereum
Expand All @@ -119,6 +121,7 @@ var PrecompiledContractsIstanbul = map[common.Address]PrecompiledContract{
common.BytesToAddress([]byte{70}): &fheLte{},
common.BytesToAddress([]byte{71}): &fheSub{},
common.BytesToAddress([]byte{72}): &fheMul{},
common.BytesToAddress([]byte{73}): &fheLt{},
}

// PrecompiledContractsBerlin contains the default set of pre-compiled Ethereum
Expand All @@ -143,6 +146,7 @@ var PrecompiledContractsBerlin = map[common.Address]PrecompiledContract{
common.BytesToAddress([]byte{70}): &fheLte{},
common.BytesToAddress([]byte{71}): &fheSub{},
common.BytesToAddress([]byte{72}): &fheMul{},
common.BytesToAddress([]byte{73}): &fheLt{},
}

// PrecompiledContractsBLS contains the set of pre-compiled Ethereum
Expand All @@ -167,6 +171,7 @@ var PrecompiledContractsBLS = map[common.Address]PrecompiledContract{
common.BytesToAddress([]byte{70}): &fheLte{},
common.BytesToAddress([]byte{71}): &fheSub{},
common.BytesToAddress([]byte{72}): &fheMul{},
common.BytesToAddress([]byte{73}): &fheLt{},
}

var (
Expand Down Expand Up @@ -1239,8 +1244,8 @@ func (e *fheAdd) Run(accessibleState PrecompileAccessibleState, caller common.Ad
return nil, errors.New("unverified ciphertext handle")
}

// If we are not committing state, skip execution and insert a random ciphertext as a result.
if !accessibleState.Interpreter().evm.Commit {
// If we are doing gas estimation, skip execution and insert a random ciphertext as a result.
if !accessibleState.Interpreter().evm.Commit && !accessibleState.Interpreter().evm.EthCall {
return importRandomCiphertext(accessibleState), nil
}

Expand Down Expand Up @@ -1306,8 +1311,8 @@ func verifyZkProof(input []byte) ([]byte, error) {
}

func (e *verifyCiphertext) Run(accessibleState PrecompileAccessibleState, caller common.Address, addr common.Address, input []byte, readOnly bool) ([]byte, error) {
// If we are not committing state, skip verificaton and insert a random ciphertext as a result.
if !accessibleState.Interpreter().evm.Commit {
// If we are doing gas estimation, skip execution and insert a random ciphertext as a result.
if !accessibleState.Interpreter().evm.Commit && !accessibleState.Interpreter().evm.EthCall {
return importRandomCiphertext(accessibleState), nil
}
var ctBytes []byte
Expand Down Expand Up @@ -1521,8 +1526,8 @@ func (e *fheLte) Run(accessibleState PrecompileAccessibleState, caller common.Ad
return nil, errors.New("unverified ciphertext handle")
}

// If we are not committing state, skip execution and insert a random ciphertext as a result.
if !accessibleState.Interpreter().evm.Commit {
// If we are doing gas estimation, skip execution and insert a random ciphertext as a result.
if !accessibleState.Interpreter().evm.Commit && !accessibleState.Interpreter().evm.EthCall {
return importRandomCiphertext(accessibleState), nil
}

Expand Down Expand Up @@ -1565,8 +1570,8 @@ func (e *fheSub) Run(accessibleState PrecompileAccessibleState, caller common.Ad
return nil, errors.New("unverified ciphertext handle")
}

// If we are not committing state, skip execution and insert a random ciphertext as a result.
if !accessibleState.Interpreter().evm.Commit {
// If we are doing gas estimation, skip execution and insert a random ciphertext as a result.
if !accessibleState.Interpreter().evm.Commit && !accessibleState.Interpreter().evm.EthCall {
return importRandomCiphertext(accessibleState), nil
}

Expand Down Expand Up @@ -1609,8 +1614,8 @@ func (e *fheMul) Run(accessibleState PrecompileAccessibleState, caller common.Ad
return nil, errors.New("unverified ciphertext handle")
}

// If we are not committing state, skip execution and insert a random ciphertext as a result.
if !accessibleState.Interpreter().evm.Commit {
// If we are doing gas estimation, skip execution and insert a random ciphertext as a result.
if !accessibleState.Interpreter().evm.Commit && !accessibleState.Interpreter().evm.EthCall {
return importRandomCiphertext(accessibleState), nil
}

Expand All @@ -1631,3 +1636,47 @@ func (e *fheMul) Run(accessibleState PrecompileAccessibleState, caller common.Ad

return ctHash[:], nil
}

type fheLt struct{}

func (e *fheLt) RequiredGas(input []byte) uint64 {
// TODO
return 8
}

func (e *fheLt) Run(accessibleState PrecompileAccessibleState, caller common.Address, addr common.Address, input []byte, readOnly bool) ([]byte, error) {
if len(input) != 64 {
return nil, errors.New("input needs to contain two 256-bit sized values")
}

lhsCt, exists := getVerifiedCiphertext(accessibleState, common.BytesToHash(input[0:32]))
if !exists {
return nil, errors.New("unverified ciphertext handle")
}
rhsCt, exists := getVerifiedCiphertext(accessibleState, common.BytesToHash(input[32:64]))
if !exists {
return nil, errors.New("unverified ciphertext handle")
}

// If we are doing gas estimation, skip execution and insert a random ciphertext as a result.
if !accessibleState.Interpreter().evm.Commit && !accessibleState.Interpreter().evm.EthCall {
return importRandomCiphertext(accessibleState), nil
}

result := lhsCt.lt(rhsCt)
verifiedCiphertext := &verifiedCiphertext{
depth: accessibleState.Interpreter().evm.depth,
ciphertext: result,
}

// TODO: for testing
err := os.WriteFile("/tmp/lt_result", verifiedCiphertext.ciphertext.serialize(), 0644)
if err != nil {
return nil, err
}

ctHash := result.getHash()
accessibleState.Interpreter().verifiedCiphertexts[ctHash] = verifiedCiphertext

return ctHash[:], nil
}
42 changes: 42 additions & 0 deletions core/vm/contracts_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -566,6 +566,48 @@ func TestFheLte(t *testing.T) {
}
}

func TestFheLt(t *testing.T) {
c := &fheLt{}
depth := 1
state := newTestState()
state.interpreter.evm.depth = depth
state.interpreter.evm.Commit = true
addr := common.Address{}
readOnly := false
_, lhs_hash := verifyCiphertextInTestState(state.interpreter, 2, depth)
_, rhs_hash := verifyCiphertextInTestState(state.interpreter, 1, depth)

// 2 < 1
input1 := toPrecompileInput(lhs_hash, rhs_hash)
out, err := c.Run(state, addr, addr, input1, readOnly)
if err != nil {
t.Fatalf(err.Error())
}
res, exists := state.interpreter.verifiedCiphertexts[common.BytesToHash(out)]
if !exists {
t.Fatalf("output ciphertext is not found in verifiedCiphertexts")
}
decrypted := res.ciphertext.decrypt()
if decrypted != 0 {
t.Fatalf("invalid decrypted result")
}

// 1 < 2
input2 := toPrecompileInput(rhs_hash, lhs_hash)
out, err = c.Run(state, addr, addr, input2, readOnly)
if err != nil {
t.Fatalf(err.Error())
}
res, exists = state.interpreter.verifiedCiphertexts[common.BytesToHash(out)]
if !exists {
t.Fatalf("output ciphertext is not found in verifiedCiphertexts")
}
decrypted = res.ciphertext.decrypt()
if decrypted != 1 {
t.Fatalf("invalid decrypted result")
}
}

func TestUnknownCiphertextHandle(t *testing.T) {
depth := 1
state := newTestState()
Expand Down
17 changes: 17 additions & 0 deletions core/vm/tfhe.go
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,14 @@ void* tfhe_lte(void* sks, void* ct1, void* ct2)
return result;
}
void* tfhe_lt(void* sks, void* ct1, void* ct2)
{
ShortintCiphertext *result = NULL;
const int r = shortint_bc_server_key_smart_less(sks, ct1, ct2, &result);
assert(r == 0);
return result;
}
uint64_t decrypt(void* cks, void* ct)
{
uint64_t res = 0;
Expand Down Expand Up @@ -297,6 +305,15 @@ func (lhs *tfheCiphertext) lte(rhs *tfheCiphertext) *tfheCiphertext {
return res
}

func (lhs *tfheCiphertext) lt(rhs *tfheCiphertext) *tfheCiphertext {
if !lhs.availableForOps() || !rhs.availableForOps() {
panic("cannot lt on a non-initialized ciphertext")
}
res := new(tfheCiphertext)
res.setPtr(C.tfhe_lt(sks, lhs.ptr, rhs.ptr))
return res
}

func (ct *tfheCiphertext) decrypt() uint64 {
if !ct.availableForOps() {
panic("cannot decrypt a null ciphertext")
Expand Down
18 changes: 18 additions & 0 deletions core/vm/tfhe_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -119,3 +119,21 @@ func TestTfheLte(t *testing.T) {
t.Fatalf("%d != %d", 0, res2)
}
}
func TestTfheLt(t *testing.T) {
a := uint64(2)
b := uint64(1)
ctA := new(tfheCiphertext)
ctA.encrypt(a)
ctB := new(tfheCiphertext)
ctB.encrypt(b)
ctRes1 := ctA.lte(ctB)
ctRes2 := ctB.lte(ctA)
res1 := ctRes1.decrypt()
res2 := ctRes2.decrypt()
if res1 != 0 {
t.Fatalf("%d != %d", 0, res1)
}
if res2 != 1 {
t.Fatalf("%d != %d", 0, res2)
}
}

0 comments on commit 3af35d4

Please sign in to comment.