From e7d6412b385260ea8653a7dcd1b34b18705dc84f Mon Sep 17 00:00:00 2001 From: Supun Setunga Date: Mon, 23 Oct 2023 13:58:19 -0700 Subject: [PATCH 1/3] Prevent nested contract deployments --- runtime/contract_update_test.go | 456 +++++++++++++++++++++++++++++--- runtime/environment.go | 14 +- runtime/stdlib/account.go | 15 +- 3 files changed, 440 insertions(+), 45 deletions(-) diff --git a/runtime/contract_update_test.go b/runtime/contract_update_test.go index b65ee3c662..2428cf4a84 100644 --- a/runtime/contract_update_test.go +++ b/runtime/contract_update_test.go @@ -236,63 +236,428 @@ func TestContractUpdateWithDependencies(t *testing.T) { require.NoError(t, err) } -func TestRuntimeInvalidContractRedeploy(t *testing.T) { +func TestRuntimeContractRedeployInSameTransaction(t *testing.T) { t.Parallel() - foo1 := []byte(` - access(all) - contract Foo { + t.Run("two additions", func(t *testing.T) { + foo1 := []byte(` access(all) - resource R { + contract Foo { access(all) - var x: Int + resource R { - init() { - self.x = 0 + access(all) + var x: Int + + init() { + self.x = 0 + } + } + + access(all) + fun createR(): @R { + return <-create R() } } + `) + foo2 := []byte(` access(all) - fun createR(): @R { - return <-create R() + contract Foo { + + access(all) + struct R { + access(all) + var x: Int + + init() { + self.x = 0 + } + } } - } - `) + `) + + tx := []byte(` + transaction(foo1: String, foo2: String) { + prepare(signer: AuthAccount) { + signer.contracts.add(name: "Foo", code: foo1.utf8) + signer.contracts.add(name: "Foo", code: foo2.utf8) + } + } + `) + + runtime := newTestInterpreterRuntime() + runtime.defaultConfig.AtreeValidationEnabled = false + + address := common.MustBytesToAddress([]byte{0x1}) + + runtimeInterface := &testRuntimeInterface{ + storage: newTestLedger(nil, nil), + getSigningAccounts: func() ([]Address, error) { + return []Address{address}, nil + }, + getAccountContractCode: func(location common.AddressLocation) ([]byte, error) { + return nil, nil + }, + resolveLocation: singleIdentifierLocationResolver(t), + updateAccountContractCode: func(location common.AddressLocation, code []byte) error { + // "delay" + return nil + }, + emitEvent: func(event cadence.Event) error { + return nil + }, + decodeArgument: func(b []byte, t cadence.Type) (value cadence.Value, err error) { + return json.Decode(nil, b) + }, + } + + nextTransactionLocation := newTransactionLocationGenerator() + + // Deploy - foo2 := []byte(` - access(all) - contract Foo { + err := runtime.ExecuteTransaction( + Script{ + Source: tx, + Arguments: encodeArgs([]cadence.Value{ + cadence.String(foo1), + cadence.String(foo2), + }), + }, + Context{ + Interface: runtimeInterface, + Location: nextTransactionLocation(), + }, + ) + RequireError(t, err) + require.ErrorContains(t, err, "cannot overwrite existing contract") + }) + + t.Run("add and update", func(t *testing.T) { + + foo1 := []byte(` access(all) - struct R { + contract Foo { + access(all) - var x: Int + resource R { - init() { - self.x = 0 + access(all) + var x: Int + + init() { + self.x = 0 + } + } + + access(all) + fun createR(): @R { + return <-create R() } } - } - `) + `) + + foo2 := []byte(` + access(all) + contract Foo { + + access(all) + struct R { + access(all) + var x: Int + + init() { + self.x = 0 + } + } + } + `) + + tx := []byte(` + transaction(foo1: String, foo2: String) { + prepare(signer: AuthAccount) { + signer.contracts.add(name: "Foo", code: foo1.utf8) + signer.contracts.update__experimental(name: "Foo", code: foo2.utf8) + } + } + `) + + runtime := newTestInterpreterRuntime() + runtime.defaultConfig.AtreeValidationEnabled = false + + address := common.MustBytesToAddress([]byte{0x1}) + + runtimeInterface := &testRuntimeInterface{ + storage: newTestLedger(nil, nil), + getSigningAccounts: func() ([]Address, error) { + return []Address{address}, nil + }, + getAccountContractCode: func(location common.AddressLocation) ([]byte, error) { + return nil, nil + }, + resolveLocation: singleIdentifierLocationResolver(t), + updateAccountContractCode: func(location common.AddressLocation, code []byte) error { + // "delay" + return nil + }, + emitEvent: func(event cadence.Event) error { + return nil + }, + decodeArgument: func(b []byte, t cadence.Type) (value cadence.Value, err error) { + return json.Decode(nil, b) + }, + } + + nextTransactionLocation := newTransactionLocationGenerator() + + // Deploy + + err := runtime.ExecuteTransaction( + Script{ + Source: tx, + Arguments: encodeArgs([]cadence.Value{ + cadence.String(foo1), + cadence.String(foo2), + }), + }, + Context{ + Interface: runtimeInterface, + Location: nextTransactionLocation(), + }, + ) + + RequireError(t, err) + require.ErrorContains(t, err, "cannot update non-existing contract") + }) +} + +func TestRuntimeNestedContractDeployment(t *testing.T) { + + t.Parallel() + + t.Run("add while adding", func(t *testing.T) { + + t.Parallel() + + contract := []byte(` + access(all) contract Foo { + + access(all) resource Bar {} + + init(){ + self.account.contracts.add( + name: "Foo", + code: "access(all) contract Foo { access(all) struct Bar {} }".utf8 + ) + } + } + `) + + runtime := newTestInterpreterRuntime() + runtime.defaultConfig.AtreeValidationEnabled = false + + address := common.MustBytesToAddress([]byte{0x1}) + + runtimeInterface := &testRuntimeInterface{ + storage: newTestLedger(nil, nil), + getSigningAccounts: func() ([]Address, error) { + return []Address{address}, nil + }, + getAccountContractCode: func(location common.AddressLocation) ([]byte, error) { + return nil, nil + }, + resolveLocation: singleIdentifierLocationResolver(t), + updateAccountContractCode: func(location common.AddressLocation, code []byte) error { + // "delay" + return nil + }, + emitEvent: func(event cadence.Event) error { + return nil + }, + decodeArgument: func(b []byte, t cadence.Type) (value cadence.Value, err error) { + return json.Decode(nil, b) + }, + } + + nextTransactionLocation := newTransactionLocationGenerator() - tx := []byte(` - transaction(foo1: String, foo2: String) { - prepare(signer: AuthAccount) { - signer.contracts.add(name: "Foo", code: foo1.utf8) - signer.contracts.add(name: "Foo", code: foo2.utf8) - } - } - `) + // Deploy + + deploymentTx := DeploymentTransaction("Foo", contract) + + err := runtime.ExecuteTransaction( + Script{ + Source: deploymentTx, + }, + Context{ + Interface: runtimeInterface, + Location: nextTransactionLocation(), + }, + ) + + RequireError(t, err) + require.ErrorContains(t, err, "cannot overwrite existing contract") + }) + + t.Run("update while adding", func(t *testing.T) { + + t.Parallel() + + contract := []byte(` + access(all) contract Foo { + + access(all) resource Bar {} + + init(){ + self.account.contracts.update__experimental( + name: "Foo", + code: "access(all) contract Foo { access(all) struct Bar {} }".utf8 + ) + } + } + `) + + runtime := newTestInterpreterRuntime() + runtime.defaultConfig.AtreeValidationEnabled = false + + address := common.MustBytesToAddress([]byte{0x1}) + + runtimeInterface := &testRuntimeInterface{ + storage: newTestLedger(nil, nil), + getSigningAccounts: func() ([]Address, error) { + return []Address{address}, nil + }, + getAccountContractCode: func(location common.AddressLocation) ([]byte, error) { + return nil, nil + }, + resolveLocation: singleIdentifierLocationResolver(t), + updateAccountContractCode: func(location common.AddressLocation, code []byte) error { + // "delay" + return nil + }, + emitEvent: func(event cadence.Event) error { + return nil + }, + decodeArgument: func(b []byte, t cadence.Type) (value cadence.Value, err error) { + return json.Decode(nil, b) + }, + } + + nextTransactionLocation := newTransactionLocationGenerator() + + // Deploy + + deploymentTx := DeploymentTransaction("Foo", contract) + + err := runtime.ExecuteTransaction( + Script{ + Source: deploymentTx, + }, + Context{ + Interface: runtimeInterface, + Location: nextTransactionLocation(), + }, + ) + + RequireError(t, err) + require.ErrorContains(t, err, "cannot update non-existing contract") + }) + + t.Run("update while updating", func(t *testing.T) { + + t.Parallel() + + deployedContract := []byte(` + access(all) contract Foo { + + access(all) resource Bar {} + + init() {} + } + `) + + contract := []byte(` + access(all) contract Foo { + + access(all) resource Bar {} + + init(){ + self.account.contracts.update__experimental( + name: "Foo", + code: "access(all) contract Foo { access(all) struct Bar {} }".utf8 + ) + } + } + `) + + runtime := newTestInterpreterRuntime() + runtime.defaultConfig.AtreeValidationEnabled = false + + address := common.MustBytesToAddress([]byte{0x1}) + + runtimeInterface := &testRuntimeInterface{ + storage: newTestLedger(nil, nil), + getSigningAccounts: func() ([]Address, error) { + return []Address{address}, nil + }, + getAccountContractCode: func(location common.AddressLocation) ([]byte, error) { + return deployedContract, nil + }, + resolveLocation: singleIdentifierLocationResolver(t), + updateAccountContractCode: func(location common.AddressLocation, code []byte) error { + // "delay" + deployedContract = code + return nil + }, + emitEvent: func(event cadence.Event) error { + return nil + }, + decodeArgument: func(b []byte, t cadence.Type) (value cadence.Value, err error) { + return json.Decode(nil, b) + }, + } + + nextTransactionLocation := newTransactionLocationGenerator() + + // Update + + updateTx := UpdateTransaction("Foo", contract) + + err := runtime.ExecuteTransaction( + Script{ + Source: updateTx, + }, + Context{ + Interface: runtimeInterface, + Location: nextTransactionLocation(), + }, + ) + + // OK: since the initializer never runs. + require.NoError(t, err) + }) +} + +func TestRuntimeContractRedeploymentInSeparateTransactions(t *testing.T) { + + t.Parallel() + + contract := []byte(` + access(all) contract Foo { + access(all) resource Bar {} + } + `) runtime := newTestInterpreterRuntime() runtime.defaultConfig.AtreeValidationEnabled = false address := common.MustBytesToAddress([]byte{0x1}) - var events []cadence.Event + var contractCode []byte runtimeInterface := &testRuntimeInterface{ storage: newTestLedger(nil, nil), @@ -300,15 +665,14 @@ func TestRuntimeInvalidContractRedeploy(t *testing.T) { return []Address{address}, nil }, getAccountContractCode: func(location common.AddressLocation) ([]byte, error) { - return nil, nil + return contractCode, nil }, resolveLocation: singleIdentifierLocationResolver(t), - updateAccountContractCode: func(location common.AddressLocation, code []byte) error { - // "delay" + updateAccountContractCode: func(_ common.AddressLocation, code []byte) error { + contractCode = code return nil }, emitEvent: func(event cadence.Event) error { - events = append(events, event) return nil }, decodeArgument: func(b []byte, t cadence.Type) (value cadence.Value, err error) { @@ -320,20 +684,30 @@ func TestRuntimeInvalidContractRedeploy(t *testing.T) { // Deploy + deploymentTx := DeploymentTransaction("Foo", contract) err := runtime.ExecuteTransaction( Script{ - Source: tx, - Arguments: encodeArgs([]cadence.Value{ - cadence.String(foo1), - cadence.String(foo2), - }), + Source: deploymentTx, }, Context{ Interface: runtimeInterface, Location: nextTransactionLocation(), }, ) - RequireError(t, err) + require.NoError(t, err) + + // Update + // Updating in a separate transaction is OK, and should not abort. - require.ErrorContains(t, err, "cannot overwrite existing contract") + updateTx := UpdateTransaction("Foo", contract) + err = runtime.ExecuteTransaction( + Script{ + Source: updateTx, + }, + Context{ + Interface: runtimeInterface, + Location: nextTransactionLocation(), + }, + ) + require.NoError(t, err) } diff --git a/runtime/environment.go b/runtime/environment.go index 426f38dcfe..e84d3159a9 100644 --- a/runtime/environment.go +++ b/runtime/environment.go @@ -122,6 +122,7 @@ type interpreterEnvironment struct { stackDepthLimiter *stackDepthLimiter checkedImports importResolutionResults config Config + deployedContracts map[Location]struct{} } var _ Environment = &interpreterEnvironment{} @@ -424,8 +425,17 @@ func (e *interpreterEnvironment) RecordContractUpdate( e.storage.recordContractUpdate(location, contractValue) } -func (e *interpreterEnvironment) ContractUpdateRecorded(location common.AddressLocation) bool { - return e.storage.contractUpdateRecorded(location) +func (e *interpreterEnvironment) TrackContractAddition(location common.AddressLocation) { + if e.deployedContracts == nil { + e.deployedContracts = map[Location]struct{}{} + } + + e.deployedContracts[location] = struct{}{} +} + +func (e *interpreterEnvironment) ContractAdditionTracked(location common.AddressLocation) bool { + _, contains := e.deployedContracts[location] + return contains } func (e *interpreterEnvironment) TemporarilyRecordCode(location common.AddressLocation, code []byte) { diff --git a/runtime/stdlib/account.go b/runtime/stdlib/account.go index 2437f40be7..750cf88ee5 100644 --- a/runtime/stdlib/account.go +++ b/runtime/stdlib/account.go @@ -1435,7 +1435,6 @@ type AccountContractAdditionHandler interface { location common.AddressLocation, value *interpreter.CompositeValue, ) - ContractUpdateRecorded(location common.AddressLocation) bool InterpretContract( location common.AddressLocation, program *interpreter.Program, @@ -1446,6 +1445,12 @@ type AccountContractAdditionHandler interface { error, ) TemporarilyRecordCode(location common.AddressLocation, code []byte) + + // TrackContractAddition records that the contract was added in the current execution. + TrackContractAddition(location common.AddressLocation) + + // ContractAdditionTracked check whether a contract has being added during the current execution. + ContractAdditionTracked(location common.AddressLocation) bool } // newAuthAccountContractsChangeFunction called when e.g. @@ -1521,7 +1526,7 @@ func newAuthAccountContractsChangeFunction( // Ensure that no contract/contract interface with the given name exists already, // and no contract deploy or update was recorded before - if len(existingCode) > 0 || handler.ContractUpdateRecorded(location) { + if len(existingCode) > 0 || handler.ContractAdditionTracked(location) { panic(errors.NewDefaultUserError( "cannot overwrite existing contract with name %q in account %s", contractName, @@ -1788,6 +1793,12 @@ func updateAccountContractCode( constructorArgumentTypes []sema.Type, options updateAccountContractCodeOptions, ) error { + + // Start tracking the contract addition. + // This must be done even before the contract code gets added, + // to avoid the same contract being updated during the deployment of itself. + handler.TrackContractAddition(location) + // If the code declares a contract, instantiate it and store it. // // This function might be called when From b0d4489bfef01935019a4ceead57386d1e9d334b Mon Sep 17 00:00:00 2001 From: Supun Setunga Date: Mon, 23 Oct 2023 15:40:54 -0700 Subject: [PATCH 2/3] Separate nested update vs sequential update tracking --- runtime/environment.go | 12 ++++++++++-- runtime/stdlib/account.go | 19 +++++++++++++------ 2 files changed, 23 insertions(+), 8 deletions(-) diff --git a/runtime/environment.go b/runtime/environment.go index e84d3159a9..01a9b10aea 100644 --- a/runtime/environment.go +++ b/runtime/environment.go @@ -425,7 +425,11 @@ func (e *interpreterEnvironment) RecordContractUpdate( e.storage.recordContractUpdate(location, contractValue) } -func (e *interpreterEnvironment) TrackContractAddition(location common.AddressLocation) { +func (e *interpreterEnvironment) ContractUpdateRecorded(location common.AddressLocation) bool { + return e.storage.contractUpdateRecorded(location) +} + +func (e *interpreterEnvironment) StartContractAddition(location common.AddressLocation) { if e.deployedContracts == nil { e.deployedContracts = map[Location]struct{}{} } @@ -433,7 +437,11 @@ func (e *interpreterEnvironment) TrackContractAddition(location common.AddressLo e.deployedContracts[location] = struct{}{} } -func (e *interpreterEnvironment) ContractAdditionTracked(location common.AddressLocation) bool { +func (e *interpreterEnvironment) EndContractAddition(location common.AddressLocation) { + delete(e.deployedContracts, location) +} + +func (e *interpreterEnvironment) IsContractBeingAdded(location common.AddressLocation) bool { _, contains := e.deployedContracts[location] return contains } diff --git a/runtime/stdlib/account.go b/runtime/stdlib/account.go index 750cf88ee5..63cf8fdcdf 100644 --- a/runtime/stdlib/account.go +++ b/runtime/stdlib/account.go @@ -1435,6 +1435,7 @@ type AccountContractAdditionHandler interface { location common.AddressLocation, value *interpreter.CompositeValue, ) + ContractUpdateRecorded(location common.AddressLocation) bool InterpretContract( location common.AddressLocation, program *interpreter.Program, @@ -1446,11 +1447,14 @@ type AccountContractAdditionHandler interface { ) TemporarilyRecordCode(location common.AddressLocation, code []byte) - // TrackContractAddition records that the contract was added in the current execution. - TrackContractAddition(location common.AddressLocation) + // StartContractAddition start adding a contract. + StartContractAddition(location common.AddressLocation) - // ContractAdditionTracked check whether a contract has being added during the current execution. - ContractAdditionTracked(location common.AddressLocation) bool + // EndContractAddition end adding the contract + EndContractAddition(location common.AddressLocation) + + // IsContractBeingAdded check whether a contract is being added in the current execution. + IsContractBeingAdded(location common.AddressLocation) bool } // newAuthAccountContractsChangeFunction called when e.g. @@ -1526,7 +1530,9 @@ func newAuthAccountContractsChangeFunction( // Ensure that no contract/contract interface with the given name exists already, // and no contract deploy or update was recorded before - if len(existingCode) > 0 || handler.ContractAdditionTracked(location) { + if len(existingCode) > 0 || + handler.ContractUpdateRecorded(location) || + handler.IsContractBeingAdded(location) { panic(errors.NewDefaultUserError( "cannot overwrite existing contract with name %q in account %s", contractName, @@ -1797,7 +1803,8 @@ func updateAccountContractCode( // Start tracking the contract addition. // This must be done even before the contract code gets added, // to avoid the same contract being updated during the deployment of itself. - handler.TrackContractAddition(location) + handler.StartContractAddition(location) + defer handler.EndContractAddition(location) // If the code declares a contract, instantiate it and store it. // From 5f8f09855a635597eeaf301ebc0e7ef0a686bc30 Mon Sep 17 00:00:00 2001 From: Supun Setunga Date: Mon, 23 Oct 2023 15:47:40 -0700 Subject: [PATCH 3/3] Refactor code MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Bastian Müller --- runtime/stdlib/account.go | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/runtime/stdlib/account.go b/runtime/stdlib/account.go index 63cf8fdcdf..45b0a17d2a 100644 --- a/runtime/stdlib/account.go +++ b/runtime/stdlib/account.go @@ -1447,13 +1447,13 @@ type AccountContractAdditionHandler interface { ) TemporarilyRecordCode(location common.AddressLocation, code []byte) - // StartContractAddition start adding a contract. + // StartContractAddition starts adding a contract. StartContractAddition(location common.AddressLocation) - // EndContractAddition end adding the contract + // EndContractAddition ends adding the contract EndContractAddition(location common.AddressLocation) - // IsContractBeingAdded check whether a contract is being added in the current execution. + // IsContractBeingAdded checks whether a contract is being added in the current execution. IsContractBeingAdded(location common.AddressLocation) bool } @@ -1533,6 +1533,7 @@ func newAuthAccountContractsChangeFunction( if len(existingCode) > 0 || handler.ContractUpdateRecorded(location) || handler.IsContractBeingAdded(location) { + panic(errors.NewDefaultUserError( "cannot overwrite existing contract with name %q in account %s", contractName,