diff --git a/packages/evm/src/evm.ts b/packages/evm/src/evm.ts index a9266197e9..272be71e88 100644 --- a/packages/evm/src/evm.ts +++ b/packages/evm/src/evm.ts @@ -733,6 +733,10 @@ export class EVM implements EVMInterface { } } + if (message.depth === 0) { + this.postMessageCleanup() + } + return { createdAddress: message.to, execResult: result, diff --git a/packages/evm/test/transientStorage.spec.ts b/packages/evm/test/transientStorage.spec.ts index 095b4b66b3..826751034d 100644 --- a/packages/evm/test/transientStorage.spec.ts +++ b/packages/evm/test/transientStorage.spec.ts @@ -1,6 +1,7 @@ -import { Address } from '@ethereumjs/util' +import { Address, equalsBytes, hexToBytes, setLengthLeft, unpadBytes } from '@ethereumjs/util' import { assert, describe, it } from 'vitest' +import { EVM } from '../src/index.js' import { TransientStorage } from '../src/transientStorage.js' describe('Transient Storage', () => { @@ -176,4 +177,38 @@ describe('Transient Storage', () => { transientStorage.revert() assert.deepEqual(transientStorage.get(address, key), value1) }) + + it('should cleanup after a message create', async () => { + const evm = await EVM.create() + // PUSH 1 PUSH 1 TSTORE + const code = hexToBytes('0x600160015D') + const keyBuf = setLengthLeft(new Uint8Array([1]), 32) + const result = await evm.runCall({ + data: code, + gasLimit: BigInt(100_000), + }) + const created = result.createdAddress! + const stored = evm.transientStorage.get(created, keyBuf) + assert.ok( + equalsBytes(unpadBytes(stored), new Uint8Array()), + 'Transient storage has been cleared' + ) + }) + + it('should cleanup after a message call', async () => { + const evm = await EVM.create() + const contractAddress = Address.zero() + // PUSH 1 PUSH 1 TSTORE + const code = hexToBytes('0x600160015D') + await evm.stateManager.putContractCode(contractAddress, code) + const keyBuf = setLengthLeft(new Uint8Array([1]), 32) + await evm.runCall({ + gasLimit: BigInt(100_000), + }) + const stored = evm.transientStorage.get(contractAddress, keyBuf) + assert.ok( + equalsBytes(unpadBytes(stored), new Uint8Array()), + 'Transient storage has been cleared' + ) + }) })