diff --git a/core/state.go b/core/state.go index e00c518943..049914f4d8 100644 --- a/core/state.go +++ b/core/state.go @@ -323,6 +323,58 @@ func (s *State) Update(blockNumber uint64, update *StateUpdate, declaredClasses return s.verifyStateUpdateRoot(update.NewRoot) } +func (s *State) GetReverseStateDiff(blockNumber uint64, diff *StateDiff) (*StateDiff, error) { + reversed := *diff + + // storage diffs + reversed.StorageDiffs = make(map[felt.Felt]map[felt.Felt]*felt.Felt, len(diff.StorageDiffs)) + for addr, storageDiffs := range diff.StorageDiffs { + reversedDiffs := make(map[felt.Felt]*felt.Felt, len(storageDiffs)) + for key := range storageDiffs { + value := &felt.Zero + if blockNumber > 0 { + oldValue, err := s.ContractStorageAt(&addr, &key, blockNumber-1) + if err != nil { + return nil, err + } + value = oldValue + } + reversedDiffs[key] = value + } + reversed.StorageDiffs[addr] = reversedDiffs + } + + // nonces + reversed.Nonces = make(map[felt.Felt]*felt.Felt, len(diff.Nonces)) + for addr := range diff.Nonces { + oldNonce := &felt.Zero + if blockNumber > 0 { + var err error + oldNonce, err = s.ContractNonceAt(&addr, blockNumber-1) + if err != nil { + return nil, err + } + } + reversed.Nonces[addr] = oldNonce + } + + // replaced + reversed.ReplacedClasses = make(map[felt.Felt]*felt.Felt, len(diff.ReplacedClasses)) + for addr := range diff.ReplacedClasses { + classHash := &felt.Zero + if blockNumber > 0 { + var err error + classHash, err = s.ContractClassHashAt(&addr, blockNumber-1) + if err != nil { + return nil, err + } + } + reversed.ReplacedClasses[addr] = classHash + } + + return &reversed, nil +} + var ( noClassContractsClassHash = new(felt.Felt).SetUint64(0) @@ -393,17 +445,27 @@ func (s *State) updateContracts(blockNumber uint64, diff *StateDiff, logChanges return fmt.Errorf("contracts is nil") } - var err error + if err := s.updateContractClasses(blockNumber, diff.ReplacedClasses, logChanges, contracts); err != nil { + return err + } + + if err := s.updateContractNonces(blockNumber, diff.Nonces, logChanges, contracts); err != nil { + return err + } - // update contract class hashes - for addr, classHash := range diff.ReplacedClasses { - contract, ok := contracts[addr] - if !ok { - contract, err = GetContract(&addr, s.txn) - if err != nil { - return err - } - contracts[addr] = contract + return s.updateContractStorages(blockNumber, diff.StorageDiffs, contracts) +} + +func (s *State) updateContractClasses( + blockNumber uint64, + replacedClasses map[felt.Felt]*felt.Felt, + logChanges bool, + contracts map[felt.Felt]*StateContract, +) error { + for addr, classHash := range replacedClasses { + contract, err := s.getOrCreateContract(addr, contracts) + if err != nil { + return err } if logChanges { @@ -414,16 +476,19 @@ func (s *State) updateContracts(blockNumber uint64, diff *StateDiff, logChanges contract.ClassHash = classHash } + return nil +} - // update contract nonces - for addr, nonce := range diff.Nonces { - contract, ok := contracts[addr] - if !ok { - contract, err = GetContract(&addr, s.txn) - if err != nil { - return err - } - contracts[addr] = contract +func (s *State) updateContractNonces( + blockNumber uint64, + nonces map[felt.Felt]*felt.Felt, + logChanges bool, + contracts map[felt.Felt]*StateContract, +) error { + for addr, nonce := range nonces { + contract, err := s.getOrCreateContract(addr, contracts) + if err != nil { + return err } if logChanges { @@ -434,29 +499,43 @@ func (s *State) updateContracts(blockNumber uint64, diff *StateDiff, logChanges contract.Nonce = nonce } + return nil +} - // update contract storages - for addr, diff := range diff.StorageDiffs { - contract, ok := contracts[addr] - if !ok { - contract, err = GetContract(&addr, s.txn) - if err != nil { - // makes sure that all noClassContracts are deployed - if _, ok := noClassContracts[addr]; ok && errors.Is(err, ErrContractNotDeployed) { - contract = NewStateContract(&addr, noClassContractsClassHash, &felt.Zero, blockNumber) - } else { - return err - } +func (s *State) updateContractStorages( + blockNumber uint64, + storageDiffs map[felt.Felt]map[felt.Felt]*felt.Felt, + contracts map[felt.Felt]*StateContract, +) error { + for addr, diff := range storageDiffs { + contract, err := s.getOrCreateContract(addr, contracts) + if err != nil { + if _, ok := noClassContracts[addr]; ok && errors.Is(err, ErrContractNotDeployed) { + contract = NewStateContract(&addr, noClassContractsClassHash, &felt.Zero, blockNumber) + contracts[addr] = contract + } else { + return err } - contracts[addr] = contract } contract.dirtyStorage = diff } - return nil } +func (s *State) getOrCreateContract(addr felt.Felt, contracts map[felt.Felt]*StateContract) (*StateContract, error) { + contract, ok := contracts[addr] + if !ok { + var err error + contract, err = GetContract(&addr, s.txn) + if err != nil { + return nil, err + } + contracts[addr] = contract + } + return contract, nil +} + type DeclaredClass struct { At uint64 Class Class @@ -561,12 +640,15 @@ func (s *State) Revert(blockNumber uint64, update *StateUpdate) error { return fmt.Errorf("remove declared classes: %v", err) } - // update contracts - reversedDiff, err := s.buildReverseDiff(blockNumber, update.StateDiff) + reversedDiff, err := s.GetReverseStateDiff(blockNumber, update.StateDiff) if err != nil { return fmt.Errorf("build reverse diff: %v", err) } + if err = s.performStateDeletions(blockNumber, reversedDiff); err != nil { + return fmt.Errorf("perform state deletions: %v", err) + } + stateTrie, storageCloser, err := s.storage() if err != nil { return err @@ -588,30 +670,8 @@ func (s *State) Revert(blockNumber uint64, update *StateUpdate) error { } } - // purge noClassContracts - // - // As noClassContracts are not in StateDiff.DeployedContracts we can only purge them if their storage no longer exists. - // Updating contracts with reverse diff will eventually lead to the deletion of noClassContract's storage key from db. Thus, - // we can use the lack of key's existence as reason for purging noClassContracts. - for addr := range noClassContracts { - contract, err := GetContract(&addr, s.txn) - if err != nil { - if !errors.Is(err, ErrContractNotDeployed) { - return err - } - continue - } - - rootKey, err := contract.StorageRoot(s.txn) - if err != nil { - return fmt.Errorf("get root key: %v", err) - } - - if rootKey.Equal(&felt.Zero) { - if err = s.purgeContract(stateTrie, &addr); err != nil { - return fmt.Errorf("purge contract: %v", err) - } - } + if err = s.purgeNoClassContracts(stateTrie); err != nil { + return fmt.Errorf("purge no class contract: %v", err) } if err = storageCloser(); err != nil { @@ -673,69 +733,59 @@ func (s *State) purgeContract(stateTrie *trie.Trie, addr *felt.Felt) error { return nil } -func (s *State) buildReverseDiff(blockNumber uint64, diff *StateDiff) (*StateDiff, error) { - reversed := *diff - +func (s *State) performStateDeletions(blockNumber uint64, diff *StateDiff) error { // storage diffs - reversed.StorageDiffs = make(map[felt.Felt]map[felt.Felt]*felt.Felt, len(diff.StorageDiffs)) for addr, storageDiffs := range diff.StorageDiffs { - reversedDiffs := make(map[felt.Felt]*felt.Felt, len(storageDiffs)) for key := range storageDiffs { - value := &felt.Zero - if blockNumber > 0 { - oldValue, err := s.ContractStorageAt(&addr, &key, blockNumber-1) - if err != nil { - return nil, err - } - value = oldValue - } - if err := s.DeleteContractStorageLog(&addr, &key, blockNumber); err != nil { - return nil, err + return err } - reversedDiffs[key] = value } - reversed.StorageDiffs[addr] = reversedDiffs } // nonces - reversed.Nonces = make(map[felt.Felt]*felt.Felt, len(diff.Nonces)) for addr := range diff.Nonces { - oldNonce := &felt.Zero - - if blockNumber > 0 { - var err error - oldNonce, err = s.ContractNonceAt(&addr, blockNumber-1) - if err != nil { - return nil, err - } - } - if err := s.DeleteContractNonceLog(&addr, blockNumber); err != nil { - return nil, err + return err } - reversed.Nonces[addr] = oldNonce } - // replaced - reversed.ReplacedClasses = make(map[felt.Felt]*felt.Felt, len(diff.ReplacedClasses)) + // replaced classes for addr := range diff.ReplacedClasses { - classHash := &felt.Zero - if blockNumber > 0 { - var err error - classHash, err = s.ContractClassHashAt(&addr, blockNumber-1) - if err != nil { - return nil, err + if err := s.DeleteContractClassHashLog(&addr, blockNumber); err != nil { + return err + } + } + + return nil +} + +func (s *State) purgeNoClassContracts(stateTrie *trie.Trie) error { + // As noClassContracts are not in StateDiff.DeployedContracts we can only purge them if their storage no longer exists. + // Updating contracts with reverse diff will eventually lead to the deletion of noClassContract's storage key from db. Thus, + // we can use the lack of key's existence as reason for purging noClassContracts. + for addr := range noClassContracts { + contract, err := GetContract(&addr, s.txn) + if err != nil { + if !errors.Is(err, ErrContractNotDeployed) { + return err } + continue } - if err := s.DeleteContractClassHashLog(&addr, blockNumber); err != nil { - return nil, err + rootKey, err := contract.StorageRoot(s.txn) + if err != nil { + return fmt.Errorf("get root key: %v", err) + } + + if rootKey.Equal(&felt.Zero) { + if err = s.purgeContract(stateTrie, &addr); err != nil { + return fmt.Errorf("purge contract: %v", err) + } } - reversed.ReplacedClasses[addr] = classHash } - return &reversed, nil + return nil } func logDBKey(key []byte, height uint64) []byte {