Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[ETHSF_HACKATHON_SUBMISSION] (EIP-1153) Transient Storage in EVM #25

Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion core/state_transition.go
Original file line number Diff line number Diff line change
Expand Up @@ -387,7 +387,7 @@ func (st *StateTransition) innerTransitionDb() (*ExecutionResult, error) {
} else {
// Increment the nonce for the next transaction
st.state.SetNonce(msg.From(), st.state.GetNonce(sender.Address())+1)
ret, st.gas, vmerr = st.evm.Call(sender, st.to(), st.data, st.gas, st.value)
ret, st.gas, vmerr = st.evm.Call(sender, vm.NewTransientStorage(), st.to(), st.data, st.gas, st.value)
}

// if deposit: skip refunds, skip tipping coinbase
Expand Down
43 changes: 34 additions & 9 deletions core/vm/evm.go
Original file line number Diff line number Diff line change
Expand Up @@ -175,7 +175,7 @@ func (evm *EVM) Interpreter() *EVMInterpreter {
// parameters. It also handles any necessary value transfer required and takes
// the necessary steps to create accounts and reverses the state in case of an
// execution error or failed value transfer.
func (evm *EVM) Call(caller ContractRef, addr common.Address, input []byte, gas uint64, value *big.Int) (ret []byte, leftOverGas uint64, err error) {
func (evm *EVM) Call(caller ContractRef, transientStorage *TransientStorage, addr common.Address, input []byte, gas uint64, value *big.Int) (ret []byte, leftOverGas uint64, err error) {
// Fail if we're trying to execute above the call depth limit
if evm.depth > int(params.CallCreateDepth) {
return nil, gas, ErrDepth
Expand All @@ -184,6 +184,10 @@ func (evm *EVM) Call(caller ContractRef, addr common.Address, input []byte, gas
if value.Sign() != 0 && !evm.Context.CanTransfer(evm.StateDB, caller.Address(), value) {
return nil, gas, ErrInsufficientBalance
}

println("Call from: ", caller.Address().String())
println("Call to:", addr.String())

snapshot := evm.StateDB.Snapshot()
p, isPrecompile := evm.precompile(addr)

Expand Down Expand Up @@ -235,21 +239,27 @@ func (evm *EVM) Call(caller ContractRef, addr common.Address, input []byte, gas
// The depth-check is already done, and precompiles handled above
contract := NewContract(caller, AccountRef(addrCopy), value, gas)
contract.SetCallCode(&addrCopy, evm.StateDB.GetCodeHash(addrCopy), code)
ret, err = evm.interpreter.Run(contract, input, false)

transientStorage.CheckPoint()
ret, err = evm.interpreter.Run(contract, transientStorage, input, false)
gas = contract.Gas
}
}
// When an error was returned by the EVM or when setting the creation code
// above we revert to the snapshot and consume any gas remaining. Additionally
// when we're in homestead this also counts for code storage gas errors.
if err != nil {
transientStorage.Revert()
evm.StateDB.RevertToSnapshot(snapshot)
if err != ErrExecutionReverted {
gas = 0
}

// TODO: consider clearing up unused snapshots:
//} else {
// evm.StateDB.DiscardSnapshot(snapshot)
} else {
transientStorage.Commit()
}
return ret, gas, err
}
Expand All @@ -261,7 +271,7 @@ func (evm *EVM) Call(caller ContractRef, addr common.Address, input []byte, gas
//
// CallCode differs from Call in the sense that it executes the given address'
// code with the caller as context.
func (evm *EVM) CallCode(caller ContractRef, addr common.Address, input []byte, gas uint64, value *big.Int) (ret []byte, leftOverGas uint64, err error) {
func (evm *EVM) CallCode(caller ContractRef, transientStorage *TransientStorage, addr common.Address, input []byte, gas uint64, value *big.Int) (ret []byte, leftOverGas uint64, err error) {
// Fail if we're trying to execute above the call depth limit
if evm.depth > int(params.CallCreateDepth) {
return nil, gas, ErrDepth
Expand Down Expand Up @@ -292,10 +302,12 @@ func (evm *EVM) CallCode(caller ContractRef, addr common.Address, input []byte,
// The contract is a scoped environment for this execution context only.
contract := NewContract(caller, AccountRef(caller.Address()), value, gas)
contract.SetCallCode(&addrCopy, evm.StateDB.GetCodeHash(addrCopy), evm.StateDB.GetCode(addrCopy))
ret, err = evm.interpreter.Run(contract, input, false)
ret, err = evm.interpreter.Run(contract, transientStorage, input, false)
gas = contract.Gas
}
if err != nil {

transientStorage.Revert()
evm.StateDB.RevertToSnapshot(snapshot)
if err != ErrExecutionReverted {
gas = 0
Expand All @@ -309,7 +321,7 @@ func (evm *EVM) CallCode(caller ContractRef, addr common.Address, input []byte,
//
// DelegateCall differs from CallCode in the sense that it executes the given address'
// code with the caller as context and the caller is set to the caller of the caller.
func (evm *EVM) DelegateCall(caller ContractRef, addr common.Address, input []byte, gas uint64) (ret []byte, leftOverGas uint64, err error) {
func (evm *EVM) DelegateCall(caller ContractRef, transientStorage *TransientStorage, addr common.Address, input []byte, gas uint64) (ret []byte, leftOverGas uint64, err error) {
// Fail if we're trying to execute above the call depth limit
if evm.depth > int(params.CallCreateDepth) {
return nil, gas, ErrDepth
Expand All @@ -332,10 +344,12 @@ func (evm *EVM) DelegateCall(caller ContractRef, addr common.Address, input []by
// Initialise a new contract and make initialise the delegate values
contract := NewContract(caller, AccountRef(caller.Address()), nil, gas).AsDelegate()
contract.SetCallCode(&addrCopy, evm.StateDB.GetCodeHash(addrCopy), evm.StateDB.GetCode(addrCopy))
ret, err = evm.interpreter.Run(contract, input, false)
ret, err = evm.interpreter.Run(contract, transientStorage, input, false)
gas = contract.Gas
}
if err != nil {

transientStorage.Revert()
evm.StateDB.RevertToSnapshot(snapshot)
if err != ErrExecutionReverted {
gas = 0
Expand All @@ -348,7 +362,7 @@ func (evm *EVM) DelegateCall(caller ContractRef, addr common.Address, input []by
// as parameters while disallowing any modifications to the state during the call.
// Opcodes that attempt to perform such modifications will result in exceptions
// instead of performing the modifications.
func (evm *EVM) StaticCall(caller ContractRef, addr common.Address, input []byte, gas uint64) (ret []byte, leftOverGas uint64, err error) {
func (evm *EVM) StaticCall(caller ContractRef, transientStorage *TransientStorage, addr common.Address, input []byte, gas uint64) (ret []byte, leftOverGas uint64, err error) {
// Fail if we're trying to execute above the call depth limit
if evm.depth > int(params.CallCreateDepth) {
return nil, gas, ErrDepth
Expand Down Expand Up @@ -388,15 +402,21 @@ func (evm *EVM) StaticCall(caller ContractRef, addr common.Address, input []byte
// When an error was returned by the EVM or when setting the creation code
// above we revert to the snapshot and consume any gas remaining. Additionally
// when we're in Homestead this also counts for code storage gas errors.
ret, err = evm.interpreter.Run(contract, input, true)

transientStorage.CheckPoint()
ret, err = evm.interpreter.Run(contract, transientStorage, input, true)
gas = contract.Gas
}
if err != nil {
transientStorage.Revert()
evm.StateDB.RevertToSnapshot(snapshot)
if err != ErrExecutionReverted {
gas = 0
}
} else {
transientStorage.Commit()
}

return ret, gas, err
}

Expand Down Expand Up @@ -460,7 +480,7 @@ func (evm *EVM) create(caller ContractRef, codeAndHash *codeAndHash, gas uint64,

start := time.Now()

ret, err := evm.interpreter.Run(contract, nil, false)
ret, err := evm.interpreter.Run(contract, nil, nil, false)

// Check whether the max code size has been exceeded, assign err if the case.
if err == nil && evm.chainRules.IsEIP158 && len(ret) > params.MaxCodeSize {
Expand Down Expand Up @@ -523,3 +543,8 @@ func (evm *EVM) Create2(caller ContractRef, code []byte, gas uint64, endowment *

// ChainConfig returns the environment's chain configuration
func (evm *EVM) ChainConfig() *params.ChainConfig { return evm.chainConfig }

func (evm *EVM) SetGasCall(gas uint64) {
evm.callGasTemp = gas
evm.interpreter.evm.callGasTemp = gas
}
2 changes: 1 addition & 1 deletion core/vm/gas_table_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ func TestEIP2200(t *testing.T) {
}
vmenv := NewEVM(vmctx, TxContext{}, statedb, params.AllEthashProtocolChanges, Config{ExtraEips: []int{2200}})

_, gas, err := vmenv.Call(AccountRef(common.Address{}), address, nil, tt.gaspool, new(big.Int))
_, gas, err := vmenv.Call(AccountRef(common.Address{}), NewTransientStorage(), address, nil, tt.gaspool, new(big.Int))
if err != tt.failure {
t.Errorf("test %d: failure mismatch: have %v, want %v", i, err, tt.failure)
}
Expand Down
32 changes: 28 additions & 4 deletions core/vm/instructions.go
Original file line number Diff line number Diff line change
Expand Up @@ -532,6 +532,29 @@ func opSstore(pc *uint64, interpreter *EVMInterpreter, scope *ScopeContext) ([]b
return nil, nil
}

func opTStore(pc *uint64, interpreter *EVMInterpreter, scope *ScopeContext) ([]byte, error) {
if interpreter.readOnly {
return nil, ErrWriteProtection
}

loc := scope.Stack.pop()
val := scope.Stack.pop()
scope.TransientStorage.Set(scope.Contract.Address(), loc.Bytes32(), val.Bytes32())

return nil, nil
}

func opTLoad(pc *uint64, interpreter *EVMInterpreter, scope *ScopeContext) ([]byte, error) {

loc := scope.Stack.peek()
hash := common.Hash(loc.Bytes32())
value := scope.TransientStorage.Get(scope.Contract.Address(), hash)

loc.SetBytes(value.Bytes())

return nil, nil
}

func opJump(pc *uint64, interpreter *EVMInterpreter, scope *ScopeContext) ([]byte, error) {
if atomic.LoadInt32(&interpreter.evm.abort) != 0 {
return nil, errStopToken
Expand Down Expand Up @@ -688,7 +711,8 @@ func opCall(pc *uint64, interpreter *EVMInterpreter, scope *ScopeContext) ([]byt
bigVal = value.ToBig()
}

ret, returnGas, err := interpreter.evm.Call(scope.Contract, toAddr, args, gas, bigVal)
println("Callling to ", toAddr.String())
ret, returnGas, err := interpreter.evm.Call(scope.Contract, scope.TransientStorage, toAddr, args, gas, bigVal)

if err != nil {
temp.Clear()
Expand Down Expand Up @@ -725,7 +749,7 @@ func opCallCode(pc *uint64, interpreter *EVMInterpreter, scope *ScopeContext) ([
bigVal = value.ToBig()
}

ret, returnGas, err := interpreter.evm.CallCode(scope.Contract, toAddr, args, gas, bigVal)
ret, returnGas, err := interpreter.evm.CallCode(scope.Contract, scope.TransientStorage, toAddr, args, gas, bigVal)
if err != nil {
temp.Clear()
} else {
Expand Down Expand Up @@ -754,7 +778,7 @@ func opDelegateCall(pc *uint64, interpreter *EVMInterpreter, scope *ScopeContext
// Get arguments from the memory.
args := scope.Memory.GetPtr(int64(inOffset.Uint64()), int64(inSize.Uint64()))

ret, returnGas, err := interpreter.evm.DelegateCall(scope.Contract, toAddr, args, gas)
ret, returnGas, err := interpreter.evm.DelegateCall(scope.Contract, scope.TransientStorage, toAddr, args, gas)
if err != nil {
temp.Clear()
} else {
Expand Down Expand Up @@ -783,7 +807,7 @@ func opStaticCall(pc *uint64, interpreter *EVMInterpreter, scope *ScopeContext)
// Get arguments from the memory.
args := scope.Memory.GetPtr(int64(inOffset.Uint64()), int64(inSize.Uint64()))

ret, returnGas, err := interpreter.evm.StaticCall(scope.Contract, toAddr, args, gas)
ret, returnGas, err := interpreter.evm.StaticCall(scope.Contract, scope.TransientStorage, toAddr, args, gas)
if err != nil {
temp.Clear()
} else {
Expand Down
18 changes: 9 additions & 9 deletions core/vm/instructions_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,7 @@ func testTwoOperandOp(t *testing.T, tests []TwoOperandTestcase, opFn executionFu
expected := new(uint256.Int).SetBytes(common.Hex2Bytes(test.Expected))
stack.push(x)
stack.push(y)
opFn(&pc, evmInterpreter, &ScopeContext{nil, stack, nil})
opFn(&pc, evmInterpreter, &ScopeContext{nil, stack, nil, nil})
if len(stack.data) != 1 {
t.Errorf("Expected one item on stack after %v, got %d: ", name, len(stack.data))
}
Expand Down Expand Up @@ -219,7 +219,7 @@ func TestAddMod(t *testing.T) {
stack.push(z)
stack.push(y)
stack.push(x)
opAddmod(&pc, evmInterpreter, &ScopeContext{nil, stack, nil})
opAddmod(&pc, evmInterpreter, &ScopeContext{nil, stack, nil, nil})
actual := stack.pop()
if actual.Cmp(expected) != 0 {
t.Errorf("Testcase %d, expected %x, got %x", i, expected, actual)
Expand All @@ -246,7 +246,7 @@ func TestWriteExpectedValues(t *testing.T) {
y := new(uint256.Int).SetBytes(common.Hex2Bytes(param.y))
stack.push(x)
stack.push(y)
opFn(&pc, interpreter, &ScopeContext{nil, stack, nil})
opFn(&pc, interpreter, &ScopeContext{nil, stack, nil, nil})
actual := stack.pop()
result[i] = TwoOperandTestcase{param.x, param.y, fmt.Sprintf("%064x", actual)}
}
Expand Down Expand Up @@ -282,7 +282,7 @@ func opBenchmark(bench *testing.B, op executionFunc, args ...string) {
var (
env = NewEVM(BlockContext{}, TxContext{}, nil, params.TestChainConfig, Config{})
stack = newstack()
scope = &ScopeContext{nil, stack, nil}
scope = &ScopeContext{nil, stack, nil, nil}
evmInterpreter = NewEVMInterpreter(env, env.Config)
)

Expand Down Expand Up @@ -533,13 +533,13 @@ func TestOpMstore(t *testing.T) {
v := "abcdef00000000000000abba000000000deaf000000c0de00100000000133700"
stack.push(new(uint256.Int).SetBytes(common.Hex2Bytes(v)))
stack.push(new(uint256.Int))
opMstore(&pc, evmInterpreter, &ScopeContext{mem, stack, nil})
opMstore(&pc, evmInterpreter, &ScopeContext{mem, stack, nil, nil})
if got := common.Bytes2Hex(mem.GetCopy(0, 32)); got != v {
t.Fatalf("Mstore fail, got %v, expected %v", got, v)
}
stack.push(new(uint256.Int).SetUint64(0x1))
stack.push(new(uint256.Int))
opMstore(&pc, evmInterpreter, &ScopeContext{mem, stack, nil})
opMstore(&pc, evmInterpreter, &ScopeContext{mem, stack, nil, nil})
if common.Bytes2Hex(mem.GetCopy(0, 32)) != "0000000000000000000000000000000000000000000000000000000000000001" {
t.Fatalf("Mstore failed to overwrite previous value")
}
Expand All @@ -563,7 +563,7 @@ func BenchmarkOpMstore(bench *testing.B) {
for i := 0; i < bench.N; i++ {
stack.push(value)
stack.push(memStart)
opMstore(&pc, evmInterpreter, &ScopeContext{mem, stack, nil})
opMstore(&pc, evmInterpreter, &ScopeContext{mem, stack, nil, nil})
}
}

Expand All @@ -583,7 +583,7 @@ func BenchmarkOpKeccak256(bench *testing.B) {
for i := 0; i < bench.N; i++ {
stack.push(uint256.NewInt(32))
stack.push(start)
opKeccak256(&pc, evmInterpreter, &ScopeContext{mem, stack, nil})
opKeccak256(&pc, evmInterpreter, &ScopeContext{mem, stack, nil, nil})
}
}

Expand Down Expand Up @@ -678,7 +678,7 @@ func TestRandom(t *testing.T) {
pc = uint64(0)
evmInterpreter = env.interpreter
)
opRandom(&pc, evmInterpreter, &ScopeContext{nil, stack, nil})
opRandom(&pc, evmInterpreter, &ScopeContext{nil, stack, nil, nil})
if len(stack.data) != 1 {
t.Errorf("Expected one item on stack after %v, got %d: ", tt.name, len(stack.data))
}
Expand Down
23 changes: 13 additions & 10 deletions core/vm/interpreter.go
Original file line number Diff line number Diff line change
Expand Up @@ -39,9 +39,10 @@ type Config struct {
// ScopeContext contains the things that are per-call, such as stack and memory,
// but not transients like pc and gas
type ScopeContext struct {
Memory *Memory
Stack *Stack
Contract *Contract
Memory *Memory
Stack *Stack
Contract *Contract
TransientStorage *TransientStorage
}

// keccakState wraps sha3.state. In addition to the usual hash methods, it also supports
Expand Down Expand Up @@ -113,7 +114,7 @@ func NewEVMInterpreter(evm *EVM, cfg Config) *EVMInterpreter {
// It's important to note that any errors returned by the interpreter should be
// considered a revert-and-consume-all-gas operation except for
// ErrExecutionReverted which means revert-and-keep-gas-left.
func (in *EVMInterpreter) Run(contract *Contract, input []byte, readOnly bool) (ret []byte, err error) {
func (in *EVMInterpreter) Run(contract *Contract, transientStorage *TransientStorage, input []byte, readOnly bool) (ret []byte, err error) {
// Increment the call depth which is restricted to 1024
in.evm.depth++
defer func() { in.evm.depth-- }()
Expand All @@ -135,13 +136,15 @@ func (in *EVMInterpreter) Run(contract *Contract, input []byte, readOnly bool) (
}

var (
op OpCode // current opcode
mem = NewMemory() // bound memory
stack = newstack() // local stack
op OpCode // current opcode
mem = NewMemory() // bound memory
stack = newstack() // local stack

callContext = &ScopeContext{
Memory: mem,
Stack: stack,
Contract: contract,
Memory: mem,
Stack: stack,
Contract: contract,
TransientStorage: transientStorage,
}
// For optimisation reason we're using uint64 as the program counter.
// It's theoretically possible to go above 2^64. The YP defines the PC
Expand Down
2 changes: 1 addition & 1 deletion core/vm/interpreter_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ func TestLoopInterrupt(t *testing.T) {
timeout := make(chan bool)

go func(evm *EVM) {
_, _, err := evm.Call(AccountRef(common.Address{}), address, nil, math.MaxUint64, new(big.Int))
_, _, err := evm.Call(AccountRef(common.Address{}), NewTransientStorage(), address, nil, math.MaxUint64, new(big.Int))
errChannel <- err
}(evm)

Expand Down
12 changes: 12 additions & 0 deletions core/vm/jump_table.go
Original file line number Diff line number Diff line change
Expand Up @@ -540,6 +540,18 @@ func newFrontierInstructionSet() JumpTable {
minStack: minStack(2, 0),
maxStack: maxStack(2, 0),
},
TLOAD: {
execute: opTLoad,
constantGas: params.WarmStorageReadCostEIP2929,
minStack: minStack(1, 1),
maxStack: minStack(1, 1),
},
TSTORE: {
execute: opTStore,
constantGas: params.WarmStorageReadCostEIP2929,
minStack: minStack(2, 0),
maxStack: minStack(2, 0),
},
JUMP: {
execute: opJump,
constantGas: GasMidStep,
Expand Down
4 changes: 4 additions & 0 deletions core/vm/opcodes.go
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,8 @@ const (
MSTORE8 OpCode = 0x53
SLOAD OpCode = 0x54
SSTORE OpCode = 0x55
TLOAD OpCode = 0xb3
TSTORE OpCode = 0xb4
JUMP OpCode = 0x56
JUMPI OpCode = 0x57
PC OpCode = 0x58
Expand Down Expand Up @@ -459,6 +461,8 @@ var stringToOp = map[string]OpCode{
"MSTORE8": MSTORE8,
"SLOAD": SLOAD,
"SSTORE": SSTORE,
"TLOAD": TLOAD,
"TSTORE": TSTORE,
"JUMP": JUMP,
"JUMPI": JUMPI,
"PC": PC,
Expand Down
Loading