From 67a9bbbc4d840858000aabb486438db2d7d435aa Mon Sep 17 00:00:00 2001 From: Quentin McGaw Date: Wed, 16 Feb 2022 14:52:56 -0500 Subject: [PATCH 1/3] fix(dot/state): inject mutex protected tries to states (#2287) --- dot/rpc/modules/dev_integration_test.go | 5 ++- dot/state/block.go | 16 ++++---- dot/state/block_data_test.go | 2 +- dot/state/block_finalisation.go | 6 ++- dot/state/block_finalisation_test.go | 6 +-- dot/state/block_notify_test.go | 15 ++++---- dot/state/block_race_test.go | 4 +- dot/state/block_test.go | 28 +++++++------- dot/state/epoch_test.go | 6 ++- dot/state/helpers_test.go | 8 ++++ dot/state/initialize.go | 6 ++- dot/state/offline_pruner.go | 6 ++- dot/state/service.go | 9 ++--- dot/state/service_test.go | 2 +- dot/state/storage.go | 29 ++------------ dot/state/storage_test.go | 28 ++++++++------ dot/state/tries.go | 18 +++++---- dot/state/tries_test.go | 50 ++++++++++++------------- lib/grandpa/grandpa_test.go | 4 +- 19 files changed, 126 insertions(+), 122 deletions(-) diff --git a/dot/rpc/modules/dev_integration_test.go b/dot/rpc/modules/dev_integration_test.go index 90c2e60ef2..34c8d0a198 100644 --- a/dot/rpc/modules/dev_integration_test.go +++ b/dot/rpc/modules/dev_integration_test.go @@ -42,8 +42,9 @@ func newState(t *testing.T) (*state.BlockState, *state.EpochState) { db := state.NewInMemoryDB(t) - _, _, genesisHeader := genesis.NewTestGenesisWithTrieAndHeader(t) - bs, err := state.NewBlockStateFromGenesis(db, genesisHeader, telemetryMock) + _, genesisTrie, genesisHeader := genesis.NewTestGenesisWithTrieAndHeader(t) + tries := state.NewTries(genesisTrie) + bs, err := state.NewBlockStateFromGenesis(db, tries, genesisHeader, telemetryMock) require.NoError(t, err) es, err := state.NewEpochStateFromGenesis(db, bs, genesisBABEConfig) require.NoError(t, err) diff --git a/dot/state/block.go b/dot/state/block.go index e31d8f40c9..59f3db4521 100644 --- a/dot/state/block.go +++ b/dot/state/block.go @@ -27,8 +27,7 @@ import ( ) const ( - pruneKeyBufferSize = 1000 - blockPrefix = "block" + blockPrefix = "block" ) var ( @@ -60,6 +59,7 @@ type BlockState struct { genesisHash common.Hash lastFinalised common.Hash unfinalisedBlocks *sync.Map // map[common.Hash]*types.Block + tries *Tries // block notifiers imported map[chan *types.Block]struct{} @@ -69,21 +69,19 @@ type BlockState struct { runtimeUpdateSubscriptionsLock sync.RWMutex runtimeUpdateSubscriptions map[uint32]chan<- runtime.Version - pruneKeyCh chan *types.Header - telemetry telemetry.Client } // NewBlockState will create a new BlockState backed by the database located at basePath -func NewBlockState(db chaindb.Database, telemetry telemetry.Client) (*BlockState, error) { +func NewBlockState(db chaindb.Database, trs *Tries, telemetry telemetry.Client) (*BlockState, error) { bs := &BlockState{ dbPath: db.Path(), baseState: NewBaseState(db), db: chaindb.NewTable(db, blockPrefix), unfinalisedBlocks: new(sync.Map), + tries: trs, imported: make(map[chan *types.Block]struct{}), finalised: make(map[chan *types.FinalisationInfo]struct{}), - pruneKeyCh: make(chan *types.Header, pruneKeyBufferSize), runtimeUpdateSubscriptions: make(map[uint32]chan<- runtime.Version), telemetry: telemetry, } @@ -107,16 +105,16 @@ func NewBlockState(db chaindb.Database, telemetry telemetry.Client) (*BlockState // NewBlockStateFromGenesis initialises a BlockState from a genesis header, // saving it to the database located at basePath -func NewBlockStateFromGenesis(db chaindb.Database, header *types.Header, - telemetryMailer telemetry.Client) 
(*BlockState, error) { +func NewBlockStateFromGenesis(db chaindb.Database, trs *Tries, header *types.Header, + telemetryMailer telemetry.Client) (*BlockState, error) { // TODO CHECKTEST bs := &BlockState{ bt: blocktree.NewBlockTreeFromRoot(header), baseState: NewBaseState(db), db: chaindb.NewTable(db, blockPrefix), unfinalisedBlocks: new(sync.Map), + tries: trs, imported: make(map[chan *types.Block]struct{}), finalised: make(map[chan *types.FinalisationInfo]struct{}), - pruneKeyCh: make(chan *types.Header, pruneKeyBufferSize), runtimeUpdateSubscriptions: make(map[uint32]chan<- runtime.Version), genesisHash: header.Hash(), lastFinalised: header.Hash(), diff --git a/dot/state/block_data_test.go b/dot/state/block_data_test.go index 1a90dd1796..6ca6df58a3 100644 --- a/dot/state/block_data_test.go +++ b/dot/state/block_data_test.go @@ -15,7 +15,7 @@ import ( ) func TestGetSet_ReceiptMessageQueue_Justification(t *testing.T) { - s := newTestBlockState(t, nil) + s := newTestBlockState(t, nil, newTriesEmpty()) require.NotNil(t, s) var genesisHeader = &types.Header{ diff --git a/dot/state/block_finalisation.go b/dot/state/block_finalisation.go index ac19f469bc..f7fe569191 100644 --- a/dot/state/block_finalisation.go +++ b/dot/state/block_finalisation.go @@ -151,8 +151,9 @@ func (bs *BlockState) SetFinalisedHash(hash common.Hash, round, setID uint64) er continue } + bs.tries.delete(block.Header.StateRoot) + logger.Tracef("pruned block number %s with hash %s", block.Header.Number, hash) - bs.pruneKeyCh <- &block.Header } // if nothing was previously finalised, set the first slot of the network to the @@ -238,8 +239,9 @@ func (bs *BlockState) handleFinalisedBlock(curr common.Hash) error { continue } + bs.tries.delete(block.Header.StateRoot) + logger.Tracef("cleaned out finalised block from memory; block number %s with hash %s", block.Header.Number, hash) - bs.pruneKeyCh <- &block.Header } return batch.Flush() diff --git a/dot/state/block_finalisation_test.go b/dot/state/block_finalisation_test.go index 20775d9a41..b55466c374 100644 --- a/dot/state/block_finalisation_test.go +++ b/dot/state/block_finalisation_test.go @@ -13,7 +13,7 @@ import ( ) func TestHighestRoundAndSetID(t *testing.T) { - bs := newTestBlockState(t, testGenesisHeader) + bs := newTestBlockState(t, testGenesisHeader, newTriesEmpty()) round, setID, err := bs.GetHighestRoundAndSetID() require.NoError(t, err) require.Equal(t, uint64(0), round) @@ -61,7 +61,7 @@ func TestHighestRoundAndSetID(t *testing.T) { } func TestBlockState_SetFinalisedHash(t *testing.T) { - bs := newTestBlockState(t, testGenesisHeader) + bs := newTestBlockState(t, testGenesisHeader, newTriesEmpty()) h, err := bs.GetFinalisedHash(0, 0) require.NoError(t, err) require.Equal(t, testGenesisHeader.Hash(), h) @@ -97,7 +97,7 @@ func TestBlockState_SetFinalisedHash(t *testing.T) { } func TestSetFinalisedHash_setFirstSlotOnFinalisation(t *testing.T) { - bs := newTestBlockState(t, testGenesisHeader) + bs := newTestBlockState(t, testGenesisHeader, newTriesEmpty()) firstSlot := uint64(42069) digest := types.NewDigest() diff --git a/dot/state/block_notify_test.go b/dot/state/block_notify_test.go index 7dbd6a9fee..f476e8836e 100644 --- a/dot/state/block_notify_test.go +++ b/dot/state/block_notify_test.go @@ -12,13 +12,14 @@ import ( "github.com/ChainSafe/gossamer/dot/types" "github.com/ChainSafe/gossamer/lib/runtime" runtimemocks "github.com/ChainSafe/gossamer/lib/runtime/mocks" + "github.com/ChainSafe/gossamer/lib/trie" "github.com/stretchr/testify/require" ) var 
testMessageTimeout = time.Second * 3 func TestImportChannel(t *testing.T) { - bs := newTestBlockState(t, testGenesisHeader) + bs := newTestBlockState(t, testGenesisHeader, NewTries(trie.NewEmptyTrie())) ch := bs.GetImportedBlockNotifierChannel() defer bs.FreeImportedBlockNotifierChannel(ch) @@ -35,7 +36,7 @@ func TestImportChannel(t *testing.T) { } func TestFreeImportedBlockNotifierChannel(t *testing.T) { - bs := newTestBlockState(t, testGenesisHeader) + bs := newTestBlockState(t, testGenesisHeader, NewTries(trie.NewEmptyTrie())) ch := bs.GetImportedBlockNotifierChannel() require.Equal(t, 1, len(bs.imported)) @@ -44,7 +45,7 @@ func TestFreeImportedBlockNotifierChannel(t *testing.T) { } func TestFinalizedChannel(t *testing.T) { - bs := newTestBlockState(t, testGenesisHeader) + bs := newTestBlockState(t, testGenesisHeader, NewTries(trie.NewEmptyTrie())) ch := bs.GetFinalisedNotifierChannel() @@ -66,7 +67,7 @@ func TestFinalizedChannel(t *testing.T) { } func TestImportChannel_Multi(t *testing.T) { - bs := newTestBlockState(t, testGenesisHeader) + bs := newTestBlockState(t, testGenesisHeader, NewTries(trie.NewEmptyTrie())) num := 5 chs := make([]chan *types.Block, num) @@ -99,7 +100,7 @@ func TestImportChannel_Multi(t *testing.T) { } func TestFinalizedChannel_Multi(t *testing.T) { - bs := newTestBlockState(t, testGenesisHeader) + bs := newTestBlockState(t, testGenesisHeader, NewTries(trie.NewEmptyTrie())) num := 5 chs := make([]chan *types.FinalisationInfo, num) @@ -136,7 +137,7 @@ func TestFinalizedChannel_Multi(t *testing.T) { } func TestService_RegisterUnRegisterRuntimeUpdatedChannel(t *testing.T) { - bs := newTestBlockState(t, testGenesisHeader) + bs := newTestBlockState(t, testGenesisHeader, NewTries(trie.NewEmptyTrie())) ch := make(chan<- runtime.Version) chID, err := bs.RegisterRuntimeUpdatedChannel(ch) require.NoError(t, err) @@ -147,7 +148,7 @@ func TestService_RegisterUnRegisterRuntimeUpdatedChannel(t *testing.T) { } func TestService_RegisterUnRegisterConcurrentCalls(t *testing.T) { - bs := newTestBlockState(t, testGenesisHeader) + bs := newTestBlockState(t, testGenesisHeader, NewTries(trie.NewEmptyTrie())) go func() { for i := 0; i < 100; i++ { diff --git a/dot/state/block_race_test.go b/dot/state/block_race_test.go index 1f1284556c..c5bc2c1bc6 100644 --- a/dot/state/block_race_test.go +++ b/dot/state/block_race_test.go @@ -28,13 +28,15 @@ func TestConcurrencySetHeader(t *testing.T) { dbs[i] = NewInMemoryDB(t) } + tries := NewTries(trie.NewEmptyTrie()) // not used in this test + pend := new(sync.WaitGroup) pend.Add(threads) for i := 0; i < threads; i++ { go func(index int) { defer pend.Done() - bs, err := NewBlockStateFromGenesis(dbs[index], testGenesisHeader, telemetryMock) + bs, err := NewBlockStateFromGenesis(dbs[index], tries, testGenesisHeader, telemetryMock) require.NoError(t, err) header := &types.Header{ diff --git a/dot/state/block_test.go b/dot/state/block_test.go index 50e6513f5f..45508e8a53 100644 --- a/dot/state/block_test.go +++ b/dot/state/block_test.go @@ -25,7 +25,7 @@ var testGenesisHeader = &types.Header{ Digest: types.NewDigest(), } -func newTestBlockState(t *testing.T, header *types.Header) *BlockState { +func newTestBlockState(t *testing.T, header *types.Header, tries *Tries) *BlockState { ctrl := gomock.NewController(t) telemetryMock := NewMockClient(ctrl) telemetryMock.EXPECT().SendMessage(gomock.Any()).AnyTimes() @@ -35,13 +35,13 @@ func newTestBlockState(t *testing.T, header *types.Header) *BlockState { header = testGenesisHeader } - bs, err := 
NewBlockStateFromGenesis(db, header, telemetryMock) + bs, err := NewBlockStateFromGenesis(db, tries, header, telemetryMock) require.NoError(t, err) return bs } func TestSetAndGetHeader(t *testing.T) { - bs := newTestBlockState(t, nil) + bs := newTestBlockState(t, nil, newTriesEmpty()) header := &types.Header{ Number: big.NewInt(0), @@ -58,7 +58,7 @@ func TestSetAndGetHeader(t *testing.T) { } func TestHasHeader(t *testing.T) { - bs := newTestBlockState(t, nil) + bs := newTestBlockState(t, nil, newTriesEmpty()) header := &types.Header{ Number: big.NewInt(0), @@ -75,7 +75,7 @@ func TestHasHeader(t *testing.T) { } func TestGetBlockByNumber(t *testing.T) { - bs := newTestBlockState(t, testGenesisHeader) + bs := newTestBlockState(t, testGenesisHeader, newTriesEmpty()) blockHeader := &types.Header{ ParentHash: testGenesisHeader.Hash(), @@ -97,7 +97,7 @@ func TestGetBlockByNumber(t *testing.T) { } func TestAddBlock(t *testing.T) { - bs := newTestBlockState(t, testGenesisHeader) + bs := newTestBlockState(t, testGenesisHeader, newTriesEmpty()) // Create header header0 := &types.Header{ @@ -160,7 +160,7 @@ func TestAddBlock(t *testing.T) { } func TestGetSlotForBlock(t *testing.T) { - bs := newTestBlockState(t, testGenesisHeader) + bs := newTestBlockState(t, testGenesisHeader, newTriesEmpty()) expectedSlot := uint64(77) babeHeader := types.NewBabeDigest() @@ -191,7 +191,7 @@ func TestGetSlotForBlock(t *testing.T) { } func TestIsBlockOnCurrentChain(t *testing.T) { - bs := newTestBlockState(t, testGenesisHeader) + bs := newTestBlockState(t, testGenesisHeader, newTriesEmpty()) currChain, branchChains := AddBlocksToState(t, bs, 3, false) for _, header := range currChain { @@ -214,7 +214,7 @@ func TestIsBlockOnCurrentChain(t *testing.T) { } func TestAddBlock_BlockNumberToHash(t *testing.T) { - bs := newTestBlockState(t, testGenesisHeader) + bs := newTestBlockState(t, testGenesisHeader, newTriesEmpty()) currChain, branchChains := AddBlocksToState(t, bs, 8, false) bestHash := bs.BestBlockHash() @@ -262,7 +262,7 @@ func TestAddBlock_BlockNumberToHash(t *testing.T) { } func TestFinalization_DeleteBlock(t *testing.T) { - bs := newTestBlockState(t, testGenesisHeader) + bs := newTestBlockState(t, testGenesisHeader, newTriesEmpty()) AddBlocksToState(t, bs, 5, false) btBefore := bs.bt.DeepCopy() @@ -317,7 +317,7 @@ func TestFinalization_DeleteBlock(t *testing.T) { } func TestGetHashByNumber(t *testing.T) { - bs := newTestBlockState(t, testGenesisHeader) + bs := newTestBlockState(t, testGenesisHeader, newTriesEmpty()) res, err := bs.GetHashByNumber(big.NewInt(0)) require.NoError(t, err) @@ -344,7 +344,7 @@ func TestGetHashByNumber(t *testing.T) { func TestAddBlock_WithReOrg(t *testing.T) { t.Skip() // TODO: this should be fixed after state refactor PR - bs := newTestBlockState(t, testGenesisHeader) + bs := newTestBlockState(t, testGenesisHeader, newTriesEmpty()) header1a := &types.Header{ Number: big.NewInt(1), @@ -453,7 +453,7 @@ func TestAddBlock_WithReOrg(t *testing.T) { } func TestAddBlockToBlockTree(t *testing.T) { - bs := newTestBlockState(t, testGenesisHeader) + bs := newTestBlockState(t, testGenesisHeader, newTriesEmpty()) header := &types.Header{ Number: big.NewInt(1), @@ -473,7 +473,7 @@ func TestAddBlockToBlockTree(t *testing.T) { } func TestNumberIsFinalised(t *testing.T) { - bs := newTestBlockState(t, testGenesisHeader) + bs := newTestBlockState(t, testGenesisHeader, newTriesEmpty()) fin, err := bs.NumberIsFinalised(big.NewInt(0)) require.NoError(t, err) require.True(t, fin) diff --git 
a/dot/state/epoch_test.go b/dot/state/epoch_test.go index 52c25386bf..de3467f7a2 100644 --- a/dot/state/epoch_test.go +++ b/dot/state/epoch_test.go @@ -11,6 +11,7 @@ import ( "github.com/ChainSafe/gossamer/dot/types" "github.com/ChainSafe/gossamer/lib/crypto/sr25519" "github.com/ChainSafe/gossamer/lib/keystore" + "github.com/ChainSafe/gossamer/lib/trie" "github.com/ChainSafe/gossamer/pkg/scale" "github.com/stretchr/testify/require" @@ -28,7 +29,8 @@ var genesisBABEConfig = &types.BabeConfiguration{ func newEpochStateFromGenesis(t *testing.T) *EpochState { db := NewInMemoryDB(t) - s, err := NewEpochStateFromGenesis(db, newTestBlockState(t, nil), genesisBABEConfig) + blockState := newTestBlockState(t, nil, NewTries(trie.NewEmptyTrie())) + s, err := NewEpochStateFromGenesis(db, blockState, genesisBABEConfig) require.NoError(t, err) return s } @@ -184,7 +186,7 @@ func TestEpochState_SetAndGetSlotDuration(t *testing.T) { func TestEpochState_GetEpochFromTime(t *testing.T) { s := newEpochStateFromGenesis(t) - s.blockState = newTestBlockState(t, testGenesisHeader) + s.blockState = newTestBlockState(t, testGenesisHeader, NewTries(trie.NewEmptyTrie())) epochDuration, err := time.ParseDuration( fmt.Sprintf("%dms", diff --git a/dot/state/helpers_test.go b/dot/state/helpers_test.go index b7057ccc62..68d6246905 100644 --- a/dot/state/helpers_test.go +++ b/dot/state/helpers_test.go @@ -8,9 +8,17 @@ import ( "testing" "time" + "github.com/ChainSafe/gossamer/lib/common" + "github.com/ChainSafe/gossamer/lib/trie" "github.com/stretchr/testify/require" ) +func newTriesEmpty() *Tries { + return &Tries{ + rootToTrie: make(map[common.Hash]*trie.Trie), + } +} + // newGenerator creates a new PRNG seeded with the // unix nanoseconds value of the current time. func newGenerator() (prng *rand.Rand) { diff --git a/dot/state/initialize.go b/dot/state/initialize.go index 782ebe89e7..7b14afb6f2 100644 --- a/dot/state/initialize.go +++ b/dot/state/initialize.go @@ -62,14 +62,16 @@ func (s *Service) Initialise(gen *genesis.Genesis, header *types.Header, t *trie return fmt.Errorf("failed to write genesis values to database: %s", err) } + tries := NewTries(t) + // create block state from genesis block - blockState, err := NewBlockStateFromGenesis(db, header, s.Telemetry) + blockState, err := NewBlockStateFromGenesis(db, tries, header, s.Telemetry) if err != nil { return fmt.Errorf("failed to create block state from genesis: %s", err) } // create storage state from genesis trie - storageState, err := NewStorageState(db, blockState, t, pruner.Config{}) + storageState, err := NewStorageState(db, blockState, tries, pruner.Config{}) if err != nil { return fmt.Errorf("failed to create storage state from trie: %s", err) } diff --git a/dot/state/offline_pruner.go b/dot/state/offline_pruner.go index eedc26d7ec..ee45e4d612 100644 --- a/dot/state/offline_pruner.go +++ b/dot/state/offline_pruner.go @@ -41,9 +41,11 @@ func NewOfflinePruner(inputDBPath, prunedDBPath string, bloomSize uint64, return nil, fmt.Errorf("failed to load DB %w", err) } + tries := NewTries(trie.NewEmptyTrie()) + // create blockState state // NewBlockState on pruner execution does not use telemetry - blockState, err := NewBlockState(db, nil) + blockState, err := NewBlockState(db, tries, nil) if err != nil { return nil, fmt.Errorf("failed to create block state: %w", err) } @@ -60,7 +62,7 @@ func NewOfflinePruner(inputDBPath, prunedDBPath string, bloomSize uint64, } // load storage state - storageState, err := NewStorageState(db, blockState, trie.NewEmptyTrie(), 
pruner.Config{}) + storageState, err := NewStorageState(db, blockState, tries, pruner.Config{}) if err != nil { return nil, fmt.Errorf("failed to create new storage state %w", err) } diff --git a/dot/state/service.go b/dot/state/service.go index 48c03f99b7..75b4ae587f 100644 --- a/dot/state/service.go +++ b/dot/state/service.go @@ -114,9 +114,11 @@ func (s *Service) Start() error { return nil } + tries := NewTries(trie.NewEmptyTrie()) + var err error // create block state - s.Block, err = NewBlockState(s.db, s.Telemetry) + s.Block, err = NewBlockState(s.db, tries, s.Telemetry) if err != nil { return fmt.Errorf("failed to create block state: %w", err) } @@ -136,7 +138,7 @@ func (s *Service) Start() error { } // create storage state - s.Storage, err = NewStorageState(s.db, s.Block, trie.NewEmptyTrie(), pr) + s.Storage, err = NewStorageState(s.db, s.Block, tries, pr) if err != nil { return fmt.Errorf("failed to create storage state: %w", err) } @@ -167,9 +169,6 @@ func (s *Service) Start() error { ", highest number " + num.String() + " and genesis hash " + s.Block.genesisHash.String()) - // Start background goroutine to GC pruned keys. - go s.Storage.pruneStorage(s.closeCh) - return nil } diff --git a/dot/state/service_test.go b/dot/state/service_test.go index 897bb96066..f8cabd4c08 100644 --- a/dot/state/service_test.go +++ b/dot/state/service_test.go @@ -289,7 +289,7 @@ func TestService_PruneStorage(t *testing.T) { time.Sleep(1 * time.Second) for _, v := range prunedArr { - tr := serv.Storage.tries.get(v.hash) + tr := serv.Storage.blockState.tries.get(v.hash) require.Nil(t, tr) } } diff --git a/dot/state/storage.go b/dot/state/storage.go index 94772187df..4571c9279b 100644 --- a/dot/state/storage.go +++ b/dot/state/storage.go @@ -30,7 +30,7 @@ func errTrieDoesNotExist(hash common.Hash) error { // StorageState is the struct that holds the trie, db and lock type StorageState struct { blockState *BlockState - tries *tries + tries *Tries db chaindb.Database sync.RWMutex @@ -41,19 +41,14 @@ type StorageState struct { pruner pruner.Pruner } -// NewStorageState creates a new StorageState backed by the given trie and database located at basePath. +// NewStorageState creates a new StorageState backed by the given block state +// and database located at basePath. 
func NewStorageState(db chaindb.Database, blockState *BlockState, - t *trie.Trie, onlinePruner pruner.Config) (*StorageState, error) { + tries *Tries, onlinePruner pruner.Config) (*StorageState, error) { if db == nil { return nil, fmt.Errorf("cannot have nil database") } - if t == nil { - return nil, fmt.Errorf("cannot have nil trie") - } - - tries := newTries(t) - storageTable := chaindb.NewTable(db, storagePrefix) var p pruner.Pruner @@ -76,11 +71,6 @@ func NewStorageState(db chaindb.Database, blockState *BlockState, }, nil } -func (s *StorageState) pruneKey(keyHeader *types.Header) { - logger.Tracef("pruning trie, number=%d hash=%s", keyHeader.Number, keyHeader.Hash()) - s.tries.delete(keyHeader.StateRoot) -} - // StoreTrie stores the given trie in the StorageState and writes it to the database func (s *StorageState) StoreTrie(ts *rtstorage.TrieState, header *types.Header) error { root := ts.MustRoot() @@ -314,14 +304,3 @@ func (s *StorageState) LoadCodeHash(hash *common.Hash) (common.Hash, error) { func (s *StorageState) GenerateTrieProof(stateRoot common.Hash, keys [][]byte) ([][]byte, error) { return trie.GenerateProof(stateRoot[:], keys, s.db) } - -func (s *StorageState) pruneStorage(closeCh chan interface{}) { - for { - select { - case key := <-s.blockState.pruneKeyCh: - s.pruneKey(key) - case <-closeCh: - return - } - } -} diff --git a/dot/state/storage_test.go b/dot/state/storage_test.go index b455a66231..cc0c756277 100644 --- a/dot/state/storage_test.go +++ b/dot/state/storage_test.go @@ -23,9 +23,11 @@ import ( func newTestStorageState(t *testing.T) *StorageState { db := NewInMemoryDB(t) - bs := newTestBlockState(t, testGenesisHeader) + tries := newTriesEmpty() - s, err := NewStorageState(db, bs, trie.NewEmptyTrie(), pruner.Config{}) + bs := newTestBlockState(t, testGenesisHeader, tries) + + s, err := NewStorageState(db, bs, tries, pruner.Config{}) require.NoError(t, err) return s } @@ -99,7 +101,7 @@ func TestStorage_TrieState(t *testing.T) { time.Sleep(time.Millisecond * 100) // get trie from db - storage.tries.delete(root) + storage.blockState.tries.delete(root) ts3, err := storage.TrieState(&root) require.NoError(t, err) require.Equal(t, ts.Trie().MustHash(), ts3.Trie().MustHash()) @@ -131,19 +133,19 @@ func TestStorage_LoadFromDB(t *testing.T) { require.NoError(t, err) // Clear trie from cache and fetch data from disk. 
- storage.tries.delete(root) + storage.blockState.tries.delete(root) data, err := storage.GetStorage(&root, trieKV[0].key) require.NoError(t, err) require.Equal(t, trieKV[0].value, data) - storage.tries.delete(root) + storage.blockState.tries.delete(root) prefixKeys, err := storage.GetKeysWithPrefix(&root, []byte("ke")) require.NoError(t, err) require.Equal(t, 2, len(prefixKeys)) - storage.tries.delete(root) + storage.blockState.tries.delete(root) entries, err := storage.Entries(&root) require.NoError(t, err) @@ -161,7 +163,7 @@ func TestStorage_StoreTrie_NotSyncing(t *testing.T) { err = storage.StoreTrie(ts, nil) require.NoError(t, err) - require.Equal(t, 2, storage.tries.len()) + require.Equal(t, 2, storage.blockState.tries.len()) } func TestGetStorageChildAndGetStorageFromChild(t *testing.T) { @@ -179,16 +181,18 @@ func TestGetStorageChildAndGetStorageFromChild(t *testing.T) { "0", )) - blockState, err := NewBlockStateFromGenesis(db, genHeader, telemetryMock) - require.NoError(t, err) - testChildTrie := trie.NewEmptyTrie() testChildTrie.Put([]byte("keyInsidechild"), []byte("voila")) err = genTrie.PutChild([]byte("keyToChild"), testChildTrie) require.NoError(t, err) - storage, err := NewStorageState(db, blockState, genTrie, pruner.Config{}) + tries := NewTries(genTrie) + + blockState, err := NewBlockStateFromGenesis(db, tries, genHeader, telemetryMock) + require.NoError(t, err) + + storage, err := NewStorageState(db, blockState, tries, pruner.Config{}) require.NoError(t, err) trieState, err := runtime.NewTrieState(genTrie) @@ -208,7 +212,7 @@ func TestGetStorageChildAndGetStorageFromChild(t *testing.T) { require.NoError(t, err) // Clear trie from cache and fetch data from disk. - storage.tries.delete(rootHash) + storage.blockState.tries.delete(rootHash) _, err = storage.GetStorageChild(&rootHash, []byte("keyToChild")) require.NoError(t, err) diff --git a/dot/state/tries.go b/dot/state/tries.go index e7afd3dbb1..21342f4d67 100644 --- a/dot/state/tries.go +++ b/dot/state/tries.go @@ -10,13 +10,17 @@ import ( "github.com/ChainSafe/gossamer/lib/trie" ) -type tries struct { +// Tries is a thread safe map of root hash +// to trie. +type Tries struct { rootToTrie map[common.Hash]*trie.Trie mapMutex sync.RWMutex } -func newTries(t *trie.Trie) *tries { - return &tries{ +// NewTries creates a new thread safe map of root hash +// to trie using the trie given as a first trie. +func NewTries(t *trie.Trie) *Tries { + return &Tries{ rootToTrie: map[common.Hash]*trie.Trie{ t.MustHash(): t, }, @@ -25,7 +29,7 @@ func newTries(t *trie.Trie) *tries { // softSet sets the given trie at the given root hash // in the memory map only if it is not already set. -func (t *tries) softSet(root common.Hash, trie *trie.Trie) { +func (t *Tries) softSet(root common.Hash, trie *trie.Trie) { t.mapMutex.Lock() defer t.mapMutex.Unlock() @@ -37,7 +41,7 @@ func (t *tries) softSet(root common.Hash, trie *trie.Trie) { t.rootToTrie[root] = trie } -func (t *tries) delete(root common.Hash) { +func (t *Tries) delete(root common.Hash) { t.mapMutex.Lock() defer t.mapMutex.Unlock() delete(t.rootToTrie, root) @@ -45,7 +49,7 @@ func (t *tries) delete(root common.Hash) { // get retrieves the trie corresponding to the root hash given // from the in-memory thread safe map. 
-func (t *tries) get(root common.Hash) (tr *trie.Trie) { +func (t *Tries) get(root common.Hash) (tr *trie.Trie) { t.mapMutex.RLock() defer t.mapMutex.RUnlock() return t.rootToTrie[root] @@ -53,7 +57,7 @@ func (t *tries) get(root common.Hash) (tr *trie.Trie) { // len returns the current numbers of tries // stored in the in-memory map. -func (t *tries) len() int { +func (t *Tries) len() int { t.mapMutex.RLock() defer t.mapMutex.RUnlock() return len(t.rootToTrie) diff --git a/dot/state/tries_test.go b/dot/state/tries_test.go index 0a0bc1d865..5e3b90e7c1 100644 --- a/dot/state/tries_test.go +++ b/dot/state/tries_test.go @@ -12,14 +12,14 @@ import ( "github.com/stretchr/testify/assert" ) -func Test_newTries(t *testing.T) { +func Test_NewTries(t *testing.T) { t.Parallel() tr := trie.NewEmptyTrie() - rootToTrie := newTries(tr) + rootToTrie := NewTries(tr) - expectedTries := &tries{ + expectedTries := &Tries{ rootToTrie: map[common.Hash]*trie.Trie{ tr.MustHash(): tr, }, @@ -28,36 +28,36 @@ func Test_newTries(t *testing.T) { assert.Equal(t, expectedTries, rootToTrie) } -func Test_tries_softSet(t *testing.T) { +func Test_Tries_softSet(t *testing.T) { t.Parallel() testCases := map[string]struct { - tries *tries + tries *Tries root common.Hash trie *trie.Trie - expectedTries *tries + expectedTries *Tries }{ "set new in map": { - tries: &tries{ + tries: &Tries{ rootToTrie: map[common.Hash]*trie.Trie{}, }, root: common.Hash{1, 2, 3}, trie: trie.NewEmptyTrie(), - expectedTries: &tries{ + expectedTries: &Tries{ rootToTrie: map[common.Hash]*trie.Trie{ {1, 2, 3}: trie.NewEmptyTrie(), }, }, }, "do not override in map": { - tries: &tries{ + tries: &Tries{ rootToTrie: map[common.Hash]*trie.Trie{ {1, 2, 3}: {}, }, }, root: common.Hash{1, 2, 3}, trie: trie.NewEmptyTrie(), - expectedTries: &tries{ + expectedTries: &Tries{ rootToTrie: map[common.Hash]*trie.Trie{ {1, 2, 3}: {}, }, @@ -77,31 +77,31 @@ func Test_tries_softSet(t *testing.T) { } } -func Test_tries_delete(t *testing.T) { +func Test_Tries_delete(t *testing.T) { t.Parallel() testCases := map[string]struct { - tries *tries + tries *Tries root common.Hash - expectedTries *tries + expectedTries *Tries }{ "not found": { - tries: &tries{ + tries: &Tries{ rootToTrie: map[common.Hash]*trie.Trie{}, }, root: common.Hash{1, 2, 3}, - expectedTries: &tries{ + expectedTries: &Tries{ rootToTrie: map[common.Hash]*trie.Trie{}, }, }, "deleted": { - tries: &tries{ + tries: &Tries{ rootToTrie: map[common.Hash]*trie.Trie{ {1, 2, 3}: {}, }, }, root: common.Hash{1, 2, 3}, - expectedTries: &tries{ + expectedTries: &Tries{ rootToTrie: map[common.Hash]*trie.Trie{}, }, }, @@ -118,16 +118,16 @@ func Test_tries_delete(t *testing.T) { }) } } -func Test_tries_get(t *testing.T) { +func Test_Tries_get(t *testing.T) { t.Parallel() testCases := map[string]struct { - tries *tries + tries *Tries root common.Hash trie *trie.Trie }{ "found in map": { - tries: &tries{ + tries: &Tries{ rootToTrie: map[common.Hash]*trie.Trie{ {1, 2, 3}: trie.NewTrie(&node.Leaf{ Key: []byte{1, 2, 3}, @@ -141,7 +141,7 @@ func Test_tries_get(t *testing.T) { }, "not found in map": { // similar to not found in database - tries: &tries{ + tries: &Tries{ rootToTrie: map[common.Hash]*trie.Trie{}, }, root: common.Hash{1, 2, 3}, @@ -160,20 +160,20 @@ func Test_tries_get(t *testing.T) { } } -func Test_tries_len(t *testing.T) { +func Test_Tries_len(t *testing.T) { t.Parallel() testCases := map[string]struct { - tries *tries + tries *Tries length int }{ "empty map": { - tries: &tries{ + tries: &Tries{ rootToTrie: 
map[common.Hash]*trie.Trie{}, }, }, "non empty map": { - tries: &tries{ + tries: &Tries{ rootToTrie: map[common.Hash]*trie.Trie{ {1, 2, 3}: {}, }, diff --git a/lib/grandpa/grandpa_test.go b/lib/grandpa/grandpa_test.go index f52c86f7a1..14a60f0714 100644 --- a/lib/grandpa/grandpa_test.go +++ b/lib/grandpa/grandpa_test.go @@ -60,7 +60,8 @@ func newTestState(t *testing.T) *state.Service { t.Cleanup(func() { db.Close() }) _, genTrie, _ := genesis.NewTestGenesisWithTrieAndHeader(t) - block, err := state.NewBlockStateFromGenesis(db, testGenesisHeader, telemetryMock) + tries := state.NewTries(genTrie) + block, err := state.NewBlockStateFromGenesis(db, tries, testGenesisHeader, telemetryMock) require.NoError(t, err) rtCfg := &wasmer.Config{} @@ -862,7 +863,6 @@ func TestFindParentWithNumber(t *testing.T) { p, err := gs.findParentWithNumber(v, 1) require.NoError(t, err) - t.Log(st.Block.BlocktreeAsString()) expected, err := st.Block.GetBlockByNumber(big.NewInt(1)) require.NoError(t, err) From 8db8b2abc798157b3af44d70f2b27e4bec6470bb Mon Sep 17 00:00:00 2001 From: Quentin McGaw Date: Thu, 17 Feb 2022 10:34:31 -0500 Subject: [PATCH 2/3] fix(dot/network): resize bytes slice buffer if needed (#2291) --- dot/network/helpers_test.go | 2 +- dot/network/inbound.go | 4 ++-- dot/network/notifications.go | 4 ++-- dot/network/service.go | 2 +- dot/network/sync.go | 2 +- dot/network/utils.go | 6 ++++-- 6 files changed, 11 insertions(+), 9 deletions(-) diff --git a/dot/network/helpers_test.go b/dot/network/helpers_test.go index f017213b69..6cc1056884 100644 --- a/dot/network/helpers_test.go +++ b/dot/network/helpers_test.go @@ -86,7 +86,7 @@ func (s *testStreamHandler) readStream(stream libp2pnetwork.Stream, }() for { - tot, err := readStream(stream, msgBytes) + tot, err := readStream(stream, &msgBytes) if errors.Is(err, io.EOF) { return } else if err != nil { diff --git a/dot/network/inbound.go b/dot/network/inbound.go index dce76de654..0437d09238 100644 --- a/dot/network/inbound.go +++ b/dot/network/inbound.go @@ -17,10 +17,9 @@ func (s *Service) readStream(stream libp2pnetwork.Stream, decoder messageDecoder peer := stream.Conn().RemotePeer() buffer := s.bufPool.Get().(*[]byte) defer s.bufPool.Put(buffer) - msgBytes := *buffer for { - n, err := readStream(stream, msgBytes[:]) + n, err := readStream(stream, buffer) if err != nil { logger.Tracef( "failed to read from stream id %s of peer %s using protocol %s: %s", @@ -32,6 +31,7 @@ func (s *Service) readStream(stream libp2pnetwork.Stream, decoder messageDecoder // decode message based on message type // stream should always be inbound if it passes through service.readStream + msgBytes := *buffer msg, err := decoder(msgBytes[:n], peer, isInbound(stream)) if err != nil { logger.Tracef("failed to decode message from stream id %s using protocol %s: %s", diff --git a/dot/network/notifications.go b/dot/network/notifications.go index c55b063186..b4d838dcbf 100644 --- a/dot/network/notifications.go +++ b/dot/network/notifications.go @@ -431,14 +431,14 @@ func (s *Service) readHandshake(stream libp2pnetwork.Stream, decoder HandshakeDe buffer := s.bufPool.Get().(*[]byte) defer s.bufPool.Put(buffer) - msgBytes := *buffer - tot, err := readStream(stream, msgBytes[:]) + tot, err := readStream(stream, buffer) if err != nil { hsC <- &handshakeReader{hs: nil, err: err} return } + msgBytes := *buffer hs, err := decoder(msgBytes[:tot]) if err != nil { s.host.cm.peerSetHandler.ReportPeer(peerset.ReputationChange{ diff --git a/dot/network/service.go b/dot/network/service.go index 
cc78982699..b0708a8e2d 100644 --- a/dot/network/service.go +++ b/dot/network/service.go @@ -34,7 +34,7 @@ const ( blockAnnounceID = "/block-announces/1" transactionsID = "/transactions/1" - maxMessageSize = 1024 * 63 // 63kb for now + maxMessageSize = 1024 * 64 // 64kb for now ) var ( diff --git a/dot/network/sync.go b/dot/network/sync.go index 4c260dad2f..d21e7a189a 100644 --- a/dot/network/sync.go +++ b/dot/network/sync.go @@ -58,7 +58,7 @@ func (s *Service) receiveBlockResponse(stream libp2pnetwork.Stream) (*BlockRespo buf := s.blockResponseBuf - n, err := readStream(stream, buf) + n, err := readStream(stream, &buf) if err != nil { return nil, fmt.Errorf("read stream error: %w", err) } diff --git a/dot/network/utils.go b/dot/network/utils.go index c36cdbd817..66683c1db5 100644 --- a/dot/network/utils.go +++ b/dot/network/utils.go @@ -176,7 +176,7 @@ func readLEB128ToUint64(r io.Reader, buf []byte) (uint64, int, error) { } // readStream reads from the stream into the given buffer, returning the number of bytes read -func readStream(stream libp2pnetwork.Stream, buf []byte) (int, error) { +func readStream(stream libp2pnetwork.Stream, bufPointer *[]byte) (int, error) { if stream == nil { return 0, errors.New("stream is nil") } @@ -185,6 +185,7 @@ func readStream(stream libp2pnetwork.Stream, buf []byte) (int, error) { tot int ) + buf := *bufPointer length, bytesRead, err := readLEB128ToUint64(stream, buf[:1]) if err != nil { return bytesRead, fmt.Errorf("failed to read length: %w", err) @@ -195,8 +196,9 @@ func readStream(stream libp2pnetwork.Stream, buf []byte) (int, error) { } if length > uint64(len(buf)) { + extraBytes := int(length) - len(buf) + *bufPointer = append(buf, make([]byte, extraBytes)...) // TODO #2288 use bytes.Buffer instead logger.Warnf("received message with size %d greater than allocated message buffer size %d", length, len(buf)) - return 0, fmt.Errorf("message size greater than allocated message buffer: got %d", length) } if length > maxBlockResponseSize { From 9ac66424b21f3d78ff2bd3c19b052e22d6238cf4 Mon Sep 17 00:00:00 2001 From: Quentin McGaw Date: Thu, 17 Feb 2022 13:43:16 -0500 Subject: [PATCH 3/3] chore(lib/trie): refactor `trie.go` and fix additional memory problems (#2255) - Linear refactor of functions for easier reading - If not modified, nodes of the previous trie generation are returned, meaning newer copied nodes will be left for garbage collection - Comments revised, reduced or added depending on the case - Split functions into smaller functions --- internal/trie/node/branch.go | 5 + internal/trie/node/branch_encode.go | 6 +- lib/trie/trie.go | 1287 ++++++++++++++++----------- lib/trie/trie_test.go | 246 +++-- 4 files changed, 923 insertions(+), 621 deletions(-) diff --git a/internal/trie/node/branch.go b/internal/trie/node/branch.go index 2883df29d9..f7e4d28799 100644 --- a/internal/trie/node/branch.go +++ b/internal/trie/node/branch.go @@ -11,6 +11,11 @@ import ( var _ Node = (*Branch)(nil) +const ( + // ChildrenCapacity is the maximum number of children in a branch node. + ChildrenCapacity = 16 +) + // Branch is a branch in the trie. 
type Branch struct { // Partial key bytes in nibbles (0 to f in hexadecimal) diff --git a/internal/trie/node/branch_encode.go b/internal/trie/node/branch_encode.go index a723633048..7408482a3e 100644 --- a/internal/trie/node/branch_encode.go +++ b/internal/trie/node/branch_encode.go @@ -161,7 +161,7 @@ var parallelEncodingRateLimit = make(chan struct{}, parallelLimit) func encodeChildrenOpportunisticParallel(children [16]Node, buffer io.Writer) (err error) { // Buffered channels since children might be encoded in this // goroutine or another one. - resultsCh := make(chan encodingAsyncResult, len(children)) + resultsCh := make(chan encodingAsyncResult, ChildrenCapacity) for i, child := range children { if isNodeNil(child) || child.Type() == LeafType { @@ -183,7 +183,7 @@ func encodeChildrenOpportunisticParallel(children [16]Node, buffer io.Writer) (e } currentIndex := 0 - resultBuffers := make([]*bytes.Buffer, len(children)) + resultBuffers := make([]*bytes.Buffer, ChildrenCapacity) for range children { result := <-resultsCh if result.err != nil && err == nil { // only set the first error we get @@ -193,7 +193,7 @@ func encodeChildrenOpportunisticParallel(children [16]Node, buffer io.Writer) (e resultBuffers[result.index] = result.buffer // write as many completed buffers to the result buffer. - for currentIndex < len(children) && + for currentIndex < ChildrenCapacity && resultBuffers[currentIndex] != nil { bufferSlice := resultBuffers[currentIndex].Bytes() if err == nil && len(bufferSlice) > 0 { diff --git a/lib/trie/trie.go b/lib/trie/trie.go index 8a294c9566..bb9faf7a01 100644 --- a/lib/trie/trie.go +++ b/lib/trie/trie.go @@ -16,13 +16,11 @@ import ( // EmptyHash is the empty trie hash. var EmptyHash, _ = NewEmptyTrie().Hash() -// Trie is a Merkle Patricia Trie. -// The zero value is an empty trie with no database. -// Use NewTrie to create a trie that sits on top of a database. +// Trie is a base 16 modified Merkle Patricia trie. type Trie struct { generation uint64 root Node - childTries map[common.Hash]*Trie // Used to store the child tries. + childTries map[common.Hash]*Trie deletedKeys map[common.Hash]struct{} } @@ -41,54 +39,60 @@ func NewTrie(root Node) *Trie { } } -// Snapshot created a copy of the trie. -func (t *Trie) Snapshot() *Trie { - children := make(map[common.Hash]*Trie) - for h, c := range t.childTries { - children[h] = &Trie{ - generation: c.generation + 1, - root: c.root, +// Snapshot creates a copy of the trie. +// Note it does not deep copy the trie, but will +// copy on write as modifications are done on this new trie. +// It does a snapshot of all child tries as well, and resets +// the set of deleted hashes. +func (t *Trie) Snapshot() (newTrie *Trie) { + childTries := make(map[common.Hash]*Trie, len(t.childTries)) + for rootHash, childTrie := range t.childTries { + childTries[rootHash] = &Trie{ + generation: childTrie.generation + 1, + root: childTrie.root.Copy(false), deletedKeys: make(map[common.Hash]struct{}), } } - newTrie := &Trie{ + return &Trie{ generation: t.generation + 1, root: t.root, - childTries: children, + childTries: childTries, deletedKeys: make(map[common.Hash]struct{}), } - - return newTrie } -func (t *Trie) maybeUpdateGeneration(n Node) Node { - if n == nil { - return nil +// updateGeneration is called when the currentNode is from +// an older trie generation (snapshot) so we deep copy the +// node and update the generation on the newer copy. 
+func updateGeneration(currentNode Node, trieGeneration uint64, + deletedHashes map[common.Hash]struct{}) (newNode Node) { + if currentNode.GetGeneration() == trieGeneration { + panic(fmt.Sprintf( + "current node has the same generation %d as the trie generation, "+ + "make sure the caller properly checks for the node generation to "+ + "be smaller than the trie generation.", trieGeneration)) } - - // Make a copy if the generation is updated. - if n.GetGeneration() < t.generation { - // Insert a new node in the current generation. - const copyChildren = false - newNode := n.Copy(copyChildren) - - newNode.SetGeneration(t.generation) - - // Hash of old nodes should already be computed since it belongs to older generation. - oldNodeHash := n.GetHash() - if len(oldNodeHash) > 0 { - hash := common.BytesToHash(oldNodeHash) - t.deletedKeys[hash] = struct{}{} - } - return newNode + const copyChildren = false + newNode = currentNode.Copy(copyChildren) + newNode.SetGeneration(trieGeneration) + + // The hash of the node from a previous snapshotted trie + // is usually already computed. + deletedHashBytes := currentNode.GetHash() + if len(deletedHashBytes) > 0 { + deletedHash := common.BytesToHash(deletedHashBytes) + deletedHashes[deletedHash] = struct{}{} } - return n + return newNode } // DeepCopy deep copies the trie and returns -// the copy. +// the copy. Note this method is meant to be used +// in tests and should not be used in production +// since it's rather inefficient compared to the copy +// on write mechanism achieved through snapshots. func (t *Trie) DeepCopy() (trieCopy *Trie) { if t == nil { return nil @@ -120,12 +124,13 @@ func (t *Trie) DeepCopy() (trieCopy *Trie) { return trieCopy } -// RootNode returns the root of the trie +// RootNode returns a copy of the root node of the trie. func (t *Trie) RootNode() Node { - return t.root + const copyChildren = false + return t.root.Copy(copyChildren) } -// encodeRoot returns the encoded root of the trie +// encodeRoot writes the encoding of the root node to the buffer. func encodeRoot(root node.Node, buffer node.Buffer) (err error) { if root == nil { _, err = buffer.Write([]byte{0}) @@ -137,7 +142,8 @@ func encodeRoot(root node.Node, buffer node.Buffer) (err error) { return root.Encode(buffer) } -// MustHash returns the hashed root of the trie. It panics if it fails to hash the root node. +// MustHash returns the hashed root of the trie. +// It panics if it fails to hash the root node. func (t *Trie) MustHash() common.Hash { h, err := t.Hash() if err != nil { @@ -147,13 +153,13 @@ func (t *Trie) MustHash() common.Hash { return h } -// Hash returns the hashed root of the trie -func (t *Trie) Hash() (common.Hash, error) { +// Hash returns the hashed root of the trie. 
+func (t *Trie) Hash() (rootHash common.Hash, err error) { buffer := pools.EncodingBuffers.Get().(*bytes.Buffer) buffer.Reset() defer pools.EncodingBuffers.Put(buffer) - err := encodeRoot(t.root, buffer) + err = encodeRoot(t.root, buffer) if err != nil { return [32]byte{}, err } @@ -167,675 +173,864 @@ func (t *Trie) Entries() map[string][]byte { return entries(t.root, nil, make(map[string][]byte)) } -func entries(current Node, prefix []byte, kv map[string][]byte) map[string][]byte { - switch c := current.(type) { - case *node.Branch: - if c.Value != nil { - kv[string(codec.NibblesToKeyLE(append(prefix, c.Key...)))] = c.Value - } - for i, child := range c.Children { - entries(child, append(prefix, append(c.Key, byte(i))...), kv) - } - case *node.Leaf: - kv[string(codec.NibblesToKeyLE(append(prefix, c.Key...)))] = c.Value +func entries(parent Node, prefix []byte, kv map[string][]byte) map[string][]byte { + if parent == nil { + return kv + } + + if parent.Type() == node.LeafType { + parentKey := parent.GetKey() + fullKeyNibbles := concatenateSlices(prefix, parentKey) + keyLE := string(codec.NibblesToKeyLE(fullKeyNibbles)) + kv[keyLE] = parent.GetValue() return kv } + // Branch with/without value + branch := parent.(*node.Branch) + + if branch.Value != nil { + fullKeyNibbles := concatenateSlices(prefix, branch.Key) + keyLE := string(codec.NibblesToKeyLE(fullKeyNibbles)) + kv[keyLE] = branch.Value + } + + for i, child := range branch.Children { + childPrefix := concatenateSlices(prefix, branch.Key, intToByteSlice(i)) + entries(child, childPrefix, kv) + } + return kv } -// NextKey returns the next key in the trie in lexicographic order. It returns nil if there is no next key -func (t *Trie) NextKey(key []byte) []byte { - k := codec.KeyLEToNibbles(key) +// NextKey returns the next key in the trie in lexicographic order. +// It returns nil if no next key is found. +func (t *Trie) NextKey(keyLE []byte) (nextKeyLE []byte) { + prefix := []byte(nil) + key := codec.KeyLEToNibbles(keyLE) - next := nextKey(t.root, nil, k) - if next == nil { + nextKey := findNextKey(t.root, prefix, key) + if nextKey == nil { return nil } - return codec.NibblesToKeyLE(next) + nextKeyLE = codec.NibblesToKeyLE(nextKey) + return nextKeyLE } -func nextKey(curr Node, prefix, key []byte) []byte { - switch c := curr.(type) { - case *node.Branch: - fullKey := append(prefix, c.Key...) - var cmp int - if len(key) < len(fullKey) { - if bytes.Compare(key, fullKey[:len(key)]) == 1 { // arg key is greater than full, return nil - return nil - } +func findNextKey(parent Node, prefix, searchKey []byte) (nextKey []byte) { + if parent == nil { + return nil + } - // the key is lexicographically less than the current node key. return first key available - cmp = 1 - } else { - // if cmp == 1, then node key is lexicographically greater than the key arg - cmp = bytes.Compare(fullKey, key[:len(fullKey)]) - } + if parent.Type() == node.LeafType { + parentLeaf := parent.(*node.Leaf) + return findNextKeyLeaf(parentLeaf, prefix, searchKey) + } - // if length of key arg is less than branch key, - // return key of first child, or key of this branch, - // if it's a branch with value. 
- if (cmp == 0 && len(key) == len(fullKey)) || cmp == 1 { - if c.Value != nil && bytes.Compare(fullKey, key) > 0 { - return fullKey - } + // Branch + parentBranch := parent.(*node.Branch) + return findNextKeyBranch(parentBranch, prefix, searchKey) +} - for i, child := range c.Children { - if child == nil { - continue - } +func findNextKeyLeaf(leaf *node.Leaf, prefix, searchKey []byte) (nextKey []byte) { + parentLeafKey := leaf.Key + fullKey := concatenateSlices(prefix, parentLeafKey) - next := nextKey(child, append(fullKey, byte(i)), key) - if len(next) != 0 { - return next - } - } - } + if keyIsLexicographicallyBigger(searchKey, fullKey) { + return nil + } - // node key isn't greater than the arg key, continue to iterate - if cmp < 1 && len(key) > len(fullKey) { - idx := key[len(fullKey)] - for i, child := range c.Children[idx:] { - if child == nil { - continue - } - - next := nextKey(child, append(fullKey, byte(i)+idx), key) - if len(next) != 0 { - return next - } - } + return fullKey +} + +func findNextKeyBranch(parentBranch *node.Branch, prefix, searchKey []byte) (nextKey []byte) { + fullKey := concatenateSlices(prefix, parentBranch.Key) + + if bytes.Equal(searchKey, fullKey) { + const startChildIndex = 0 + return findNextKeyChild(parentBranch.Children, startChildIndex, fullKey, searchKey) + } + + if keyIsLexicographicallyBigger(searchKey, fullKey) { + if len(searchKey) < len(fullKey) { + return nil + } else if len(searchKey) > len(fullKey) { + startChildIndex := searchKey[len(fullKey)] + return findNextKeyChild(parentBranch.Children, + startChildIndex, fullKey, searchKey) } - case *node.Leaf: - fullKey := append(prefix, c.Key...) - var cmp int - if len(key) < len(fullKey) { - if bytes.Compare(key, fullKey[:len(key)]) == 1 { // arg key is greater than full, return nil - return nil - } + } - // the key is lexicographically less than the current node key. return first key available - cmp = 1 - } else { - // if cmp == 1, then node key is lexicographically greater than the key arg - cmp = bytes.Compare(fullKey, key[:len(fullKey)]) + // search key is smaller than full key + if parentBranch.Value != nil { + return fullKey + } + const startChildIndex = 0 + return findNextKeyChild(parentBranch.Children, startChildIndex, + fullKey, searchKey) +} + +func keyIsLexicographicallyBigger(key, key2 []byte) (bigger bool) { + if len(key) < len(key2) { + return bytes.Compare(key, key2[:len(key)]) == 1 + } + return bytes.Compare(key[:len(key2)], key2) != -1 +} + +// findNextKeyChild searches for a next key in the children +// given and returns a next key or nil if no next key is found. +func findNextKeyChild(children [16]node.Node, startIndex byte, + fullKey, key []byte) (nextKey []byte) { + for i := startIndex; i < node.ChildrenCapacity; i++ { + child := children[i] + if child == nil { + continue } - if cmp == 1 { - return append(prefix, c.Key...) + childFullKey := concatenateSlices(fullKey, []byte{i}) + next := findNextKey(child, childFullKey, key) + if len(next) > 0 { + return next } - case nil: - return nil } + return nil } -// Put inserts a key with value into the trie -func (t *Trie) Put(key, value []byte) { - nibblesKey := codec.KeyLEToNibbles(key) - t.tryPut(nibblesKey, value) +// Put inserts a value into the trie at the +// key specified in little Endian format. 
+func (t *Trie) Put(keyLE, value []byte) { + nibblesKey := codec.KeyLEToNibbles(keyLE) + t.put(nibblesKey, value) } -func (t *Trie) tryPut(key, value []byte) { - t.root = t.insert(t.root, key, node.NewLeaf(nil, value, true, t.generation)) +func (t *Trie) put(key, value []byte) { + nodeToInsert := &node.Leaf{ + Value: value, + Generation: t.generation, + Dirty: true, + } + t.root = t.insert(t.root, key, nodeToInsert) } // insert attempts to insert a key with value into the trie -func (t *Trie) insert(parent Node, key []byte, value Node) Node { - newParent := t.maybeUpdateGeneration(parent) - value.SetGeneration(t.generation) +func (t *Trie) insert(parent Node, key []byte, value Node) (newParent Node) { + // TODO change value node to be value []byte? + value.SetGeneration(t.generation) // just in case it's not set by the caller. - if newParent == nil { + if parent == nil { value.SetKey(key) return value } + // TODO ensure all values have dirty set to true + newParent = parent + if parent.GetGeneration() < t.generation { + newParent = updateGeneration(parent, t.generation, t.deletedKeys) + } + switch newParent.Type() { case node.BranchType, node.BranchWithValueType: - p := newParent.(*node.Branch) - n := t.updateBranch(p, key, value) + parentBranch := newParent.(*node.Branch) + return t.insertInBranch(parentBranch, key, value) + default: + parentLeaf := newParent.(*node.Leaf) + return t.insertInLeaf(parentLeaf, key, value) + } +} - if p != nil && n != nil && n.IsDirty() { - p.SetDirty(true) - } - return n - case node.LeafType: - p := newParent.(*node.Leaf) - // if a value already exists in the trie at this key, overwrite it with the new value - // if the values are the same, don't mark node dirty - if bytes.Equal(p.Key, key) { - if !bytes.Equal(value.(*node.Leaf).Value, p.Value) { - p.Value = value.(*node.Leaf).Value - p.SetDirty(true) - } - return p +func (t *Trie) insertInLeaf(parentLeaf *node.Leaf, key []byte, + value Node) (newParent Node) { + newValue := value.(*node.Leaf).Value + + if bytes.Equal(parentLeaf.Key, key) { + if !bytes.Equal(newValue, parentLeaf.Value) { + parentLeaf.Value = newValue + parentLeaf.SetDirty(true) } + return parentLeaf + } - length := lenCommonPrefix(key, p.Key) + commonPrefixLength := lenCommonPrefix(key, parentLeaf.Key) - // need to convert this leaf into a branch - var newBranchValue []byte - const newBranchDirty = true - br := node.NewBranch(key[:length], newBranchValue, newBranchDirty, t.generation) - parentKey := p.Key + // Convert the current leaf parent into a branch parent + newBranchParent := &node.Branch{ + Key: key[:commonPrefixLength], + Generation: t.generation, + Dirty: true, + } + parentLeafKey := parentLeaf.Key + + if len(key) == commonPrefixLength { + // key is included in parent leaf key + newBranchParent.Value = newValue + + if len(key) < len(parentLeafKey) { + // Move the current leaf parent as a child to the new branch. 
+ childIndex := parentLeafKey[commonPrefixLength] + parentLeaf.Key = parentLeaf.Key[commonPrefixLength+1:] + parentLeaf.Dirty = true + newBranchParent.Children[childIndex] = parentLeaf + } - // value goes at this branch - if len(key) == length { - br.Value = value.(*node.Leaf).Value - br.SetDirty(true) + return newBranchParent + } - // if we are not replacing previous leaf, then add it as a child to the new branch - if len(parentKey) > len(key) { - p.Key = p.Key[length+1:] - br.Children[parentKey[length]] = p - p.SetDirty(true) - } + value.SetKey(key[commonPrefixLength+1:]) - return br - } + if len(parentLeaf.Key) == commonPrefixLength { + // the key of the parent leaf is at this new branch + newBranchParent.Value = parentLeaf.Value + } else { + // make the leaf a child of the new branch + childIndex := parentLeafKey[commonPrefixLength] + parentLeaf.Key = parentLeaf.Key[commonPrefixLength+1:] + parentLeaf.SetDirty(true) + newBranchParent.Children[childIndex] = parentLeaf + } + childIndex := key[commonPrefixLength] + newBranchParent.Children[childIndex] = value - value.SetKey(key[length+1:]) + return newBranchParent +} - if length == len(p.Key) { - // if leaf's key is covered by this branch, then make the leaf's - // value the value at this branch - br.Value = p.Value - br.Children[key[length]] = value - } else { - // otherwise, make the leaf a child of the branch and update its partial key - p.Key = p.Key[length+1:] - p.SetDirty(true) - br.Children[parentKey[length]] = p - br.Children[key[length]] = value - } +func (t *Trie) insertInBranch(parentBranch *node.Branch, key []byte, value Node) (newParent Node) { + if bytes.Equal(key, parentBranch.Key) { + parentBranch.SetDirty(true) + parentBranch.Value = value.GetValue() + return parentBranch + } - return br - default: - panic("unknown node type: " + fmt.Sprint(newParent.Type())) - } -} - -// updateBranch attempts to add the value node to a branch -// inserts the value node as the branch's child at the index that's -// the first nibble of the key -func (t *Trie) updateBranch(p *node.Branch, key []byte, value Node) (n Node) { - length := lenCommonPrefix(key, p.Key) - - // whole parent key matches - if length == len(p.Key) { - // if node has same key as this branch, then update the value at this branch - if bytes.Equal(key, p.Key) { - p.SetDirty(true) - switch v := value.(type) { - case *node.Branch: - p.Value = v.Value - case *node.Leaf: - p.Value = v.Value + if bytes.HasPrefix(key, parentBranch.Key) { + // key is included in parent branch key + commonPrefixLength := lenCommonPrefix(key, parentBranch.Key) + childIndex := key[commonPrefixLength] + remainingKey := key[commonPrefixLength+1:] + child := parentBranch.Children[childIndex] + + if child == nil { + child = &node.Leaf{ + Key: remainingKey, + Value: value.GetValue(), + Generation: t.generation, + Dirty: true, } - return p - } - - switch c := p.Children[key[length]].(type) { - case *node.Branch, *node.Leaf: - n = t.insert(c, key[length+1:], value) - p.Children[key[length]] = n - n.SetDirty(true) - p.SetDirty(true) - return p - case nil: - // otherwise, add node as child of this branch - value.(*node.Leaf).Key = key[length+1:] - p.Children[key[length]] = value - p.SetDirty(true) - return p + } else { + child = t.insert(child, remainingKey, value) + child.SetDirty(true) } - return n + parentBranch.Children[childIndex] = child + parentBranch.SetDirty(true) + return parentBranch } // we need to branch out at the point where the keys diverge // update partial keys, new branch has key up to 
matching length - var newBranchValue []byte - const newBranchDirty = true - br := node.NewBranch(key[:length], newBranchValue, newBranchDirty, t.generation) + commonPrefixLength := lenCommonPrefix(key, parentBranch.Key) + newParentBranch := &node.Branch{ + Key: key[:commonPrefixLength], + Generation: t.generation, + Dirty: true, + } + parentBranch.SetDirty(true) - parentIndex := p.Key[length] - br.Children[parentIndex] = t.insert(nil, p.Key[length+1:], p) + oldParentIndex := parentBranch.Key[commonPrefixLength] + remainingOldParentKey := parentBranch.Key[commonPrefixLength+1:] + newParentBranch.Children[oldParentIndex] = t.insert(nil, remainingOldParentKey, parentBranch) - if len(key) <= length { - br.Value = value.(*node.Leaf).Value + if len(key) <= commonPrefixLength { + newParentBranch.Value = value.(*node.Leaf).Value } else { - br.Children[key[length]] = t.insert(nil, key[length+1:], value) + childIndex := key[commonPrefixLength] + remainingKey := key[commonPrefixLength+1:] + newParentBranch.Children[childIndex] = t.insert(nil, remainingKey, value) } - br.SetDirty(true) - return br + return newParentBranch } -// LoadFromMap loads the given data into trie -func (t *Trie) LoadFromMap(data map[string]string) error { +// LoadFromMap loads the given data mapping of key to value into the trie. +// The keys are in hexadecimal little Endian encoding and the values +// are hexadecimal encoded. +func (t *Trie) LoadFromMap(data map[string]string) (err error) { for key, value := range data { - keyBytes, err := common.HexToBytes(key) + keyLEBytes, err := common.HexToBytes(key) if err != nil { - return err + return fmt.Errorf("cannot convert key hex to bytes: %w", err) } + valueBytes, err := common.HexToBytes(value) if err != nil { - return err + return fmt.Errorf("cannot convert value hex to bytes: %w", err) } - t.Put(keyBytes, valueBytes) + + t.Put(keyLEBytes, valueBytes) } return nil } -// GetKeysWithPrefix returns all keys in the trie that have the given prefix -func (t *Trie) GetKeysWithPrefix(prefix []byte) [][]byte { - var p []byte - if len(prefix) != 0 { - p = codec.KeyLEToNibbles(prefix) - if p[len(p)-1] == 0 { - p = p[:len(p)-1] - } +// GetKeysWithPrefix returns all keys in little Endian +// format from nodes in the trie that have the given little +// Endian formatted prefix in their key. +func (t *Trie) GetKeysWithPrefix(prefixLE []byte) (keysLE [][]byte) { + var prefixNibbles []byte + if len(prefixLE) > 0 { + prefixNibbles = codec.KeyLEToNibbles(prefixLE) + prefixNibbles = bytes.TrimSuffix(prefixNibbles, []byte{0}) } - return getKeysWithPrefix(t.root, []byte{}, p, [][]byte{}) + prefix := []byte{} + key := prefixNibbles + return getKeysWithPrefix(t.root, prefix, key, keysLE) } -func getKeysWithPrefix(parent Node, prefix, key []byte, keys [][]byte) [][]byte { - switch p := parent.(type) { - case *node.Branch: - length := lenCommonPrefix(p.Key, key) +// getKeysWithPrefix returns all keys in little Endian format that have the +// prefix given. The prefix and key byte slices are in nibbles format. +// TODO pass in map of keysLE if order is not needed. +// TODO do all processing on nibbles keys and then convert to LE. 
+func getKeysWithPrefix(parent Node, prefix, key []byte, + keysLE [][]byte) (newKeysLE [][]byte) { + if parent == nil { + return keysLE + } - if bytes.Equal(p.Key[:length], key) || len(key) == 0 { - // node has prefix, add to list and add all descendant nodes to list - keys = addAllKeys(p, prefix, keys) - return keys - } + if parent.Type() == node.LeafType { + parentLeaf := parent.(*node.Leaf) + return getKeysWithPrefixFromLeaf(parentLeaf, prefix, key, keysLE) + } - if len(key) <= len(p.Key) || length < len(p.Key) { - // no prefixed keys to be found here, return - return keys - } + parentBranch := parent.(*node.Branch) + return getKeysWithPrefixFromBranch(parentBranch, prefix, key, keysLE) +} - key = key[len(p.Key):] - keys = getKeysWithPrefix(p.Children[key[0]], append(append(prefix, p.Key...), key[0]), key[1:], keys) - case *node.Leaf: - length := lenCommonPrefix(p.Key, key) - if bytes.Equal(p.Key[:length], key) || len(key) == 0 { - keys = append(keys, codec.NibblesToKeyLE(append(prefix, p.Key...))) - } - case nil: - return keys +func getKeysWithPrefixFromLeaf(parent *node.Leaf, prefix, key []byte, + keysLE [][]byte) (newKeysLE [][]byte) { + if len(key) == 0 || bytes.HasPrefix(parent.Key, key) { + fullKeyLE := makeFullKeyLE(prefix, parent.Key) + keysLE = append(keysLE, fullKeyLE) } - return keys + return keysLE } -// addAllKeys appends all keys that are descendants of the parent node to a slice of keys -// it uses the prefix to determine the entire key -func addAllKeys(parent Node, prefix []byte, keys [][]byte) [][]byte { - switch p := parent.(type) { - case *node.Branch: - if p.Value != nil { - keys = append(keys, codec.NibblesToKeyLE(append(prefix, p.Key...))) - } +func getKeysWithPrefixFromBranch(parent *node.Branch, prefix, key []byte, + keysLE [][]byte) (newKeysLE [][]byte) { + if len(key) == 0 || bytes.HasPrefix(parent.Key, key) { + return addAllKeys(parent, prefix, keysLE) + } - for i, child := range p.Children { - keys = addAllKeys(child, append(append(prefix, p.Key...), byte(i)), keys) - } - case *node.Leaf: - keys = append(keys, codec.NibblesToKeyLE(append(prefix, p.Key...))) - case nil: - return keys + noPossiblePrefixedKeys := + len(parent.Key) > len(key) && + !bytes.HasPrefix(parent.Key, key) + if noPossiblePrefixedKeys { + return keysLE } - return keys + key = key[len(parent.Key):] + childIndex := key[0] + child := parent.Children[childIndex] + childPrefix := makeChildPrefix(prefix, parent.Key, int(childIndex)) + childKey := key[1:] + return getKeysWithPrefix(child, childPrefix, childKey, keysLE) } -// Get returns the value for key stored in the trie at the corresponding key -func (t *Trie) Get(key []byte) []byte { - keyNibbles := codec.KeyLEToNibbles(key) +// addAllKeys appends all keys of descendant nodes of the parent node +// to the slice of keys given and returns this slice. +// It uses the prefix in nibbles format to determine the full key. +// The slice of keys has its keys formatted in little Endian. 
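+// For example, a leaf with key nibbles {3, 4} reached with prefix nibbles
+// {1, 2} contributes codec.NibblesToKeyLE of the nibbles {1, 2, 3, 4}
+// to the returned slice.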
+func addAllKeys(parent Node, prefix []byte, keysLE [][]byte) (newKeysLE [][]byte) { + if parent == nil { + return keysLE + } + + if parent.Type() == node.LeafType { + keyLE := makeFullKeyLE(prefix, parent.GetKey()) + keysLE = append(keysLE, keyLE) + return keysLE + } + + // Branches + branchParent := parent.(*node.Branch) + if branchParent.Value != nil { + keyLE := makeFullKeyLE(prefix, branchParent.Key) + keysLE = append(keysLE, keyLE) + } + + for i, child := range branchParent.Children { + childPrefix := makeChildPrefix(prefix, branchParent.Key, i) + keysLE = addAllKeys(child, childPrefix, keysLE) + } + + return keysLE +} + +func makeFullKeyLE(prefix, nodeKey []byte) (fullKeyLE []byte) { + fullKey := concatenateSlices(prefix, nodeKey) + fullKeyLE = codec.NibblesToKeyLE(fullKey) + return fullKeyLE +} + +func makeChildPrefix(branchPrefix, branchKey []byte, + childIndex int) (childPrefix []byte) { + childPrefix = concatenateSlices(branchPrefix, branchKey, intToByteSlice(childIndex)) + return childPrefix +} + +// Get returns the value in the node of the trie +// which matches its key with the key given. +// Note the key argument is given in little Endian format. +func (t *Trie) Get(keyLE []byte) (value []byte) { + keyNibbles := codec.KeyLEToNibbles(keyLE) return retrieve(t.root, keyNibbles) } func retrieve(parent Node, key []byte) (value []byte) { - switch p := parent.(type) { - case *node.Branch: - length := lenCommonPrefix(p.Key, key) + if parent == nil { + return nil + } - // found the value at this node - if bytes.Equal(p.Key, key) || len(key) == 0 { - return p.Value - } + if parent.Type() == node.LeafType { + leaf := parent.(*node.Leaf) + return retrieveFromLeaf(leaf, key) + } - // did not find value - if bytes.Equal(p.Key[:length], key) && len(key) < len(p.Key) { - return nil - } + // Branches + branch := parent.(*node.Branch) + return retrieveFromBranch(branch, key) +} - value = retrieve(p.Children[key[length]], key[length+1:]) - case *node.Leaf: - if bytes.Equal(p.Key, key) { - value = p.Value - } - case nil: +func retrieveFromLeaf(leaf *node.Leaf, key []byte) (value []byte) { + if bytes.Equal(leaf.Key, key) { + return leaf.Value + } + return nil +} + +func retrieveFromBranch(branch *node.Branch, key []byte) (value []byte) { + if len(key) == 0 || bytes.Equal(branch.Key, key) { + return branch.Value + } + + if len(branch.Key) > len(key) && bytes.HasPrefix(branch.Key, key) { return nil } - return value + + commonPrefixLength := lenCommonPrefix(branch.Key, key) + childIndex := key[commonPrefixLength] + childKey := key[commonPrefixLength+1:] + child := branch.Children[childIndex] + return retrieve(child, childKey) } -// ClearPrefixLimit deletes the keys having the prefix till limit reached -func (t *Trie) ClearPrefixLimit(prefix []byte, limit uint32) (uint32, bool) { +// ClearPrefixLimit deletes the keys having the prefix given in little +// Endian format for up to `limit` keys. It returns the number of deleted +// keys and a boolean indicating if all keys with the prefix were deleted +// within the limit. 
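+// For instance, if the trie only contains the keys 0x11, 0x12 and 0x13,
+// ClearPrefixLimit([]byte{0x10}, 2) deletes two of them and returns (2, false),
+// whereas a limit of 3 or more deletes all three and returns (3, true).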
+func (t *Trie) ClearPrefixLimit(prefixLE []byte, limit uint32) (deleted uint32, allDeleted bool) { if limit == 0 { return 0, false } - p := codec.KeyLEToNibbles(prefix) - if len(p) > 0 && p[len(p)-1] == 0 { - p = p[:len(p)-1] - } + prefix := codec.KeyLEToNibbles(prefixLE) + prefix = bytes.TrimSuffix(prefix, []byte{0}) - l := limit - var allDeleted bool - t.root, _, allDeleted = t.clearPrefixLimit(t.root, p, &limit) - return l - limit, allDeleted + initialLimit := limit + t.root, _, allDeleted = t.clearPrefixLimit(t.root, prefix, &limit) + deleted = initialLimit - limit + return deleted, allDeleted } // clearPrefixLimit deletes the keys having the prefix till limit reached and returns updated trie root node, // true if any node in the trie got updated, and next bool returns true if there is no keys left with prefix. -func (t *Trie) clearPrefixLimit(cn Node, prefix []byte, limit *uint32) (Node, bool, bool) { - curr := t.maybeUpdateGeneration(cn) - - switch c := curr.(type) { - case *node.Branch: - length := lenCommonPrefix(c.Key, prefix) - if length == len(prefix) { - n := t.deleteNodes(c, []byte{}, limit) - if n == nil { - return nil, true, true - } - return n, true, false +// TODO return deleted count and deduce updated from deleted count, do not pass limit as pointer. +func (t *Trie) clearPrefixLimit(parent Node, prefix []byte, limit *uint32) ( + newParent Node, updated bool, allDeleted bool) { + if parent == nil { + return nil, false, true + } + + newParent = parent + if parent.GetGeneration() < t.generation { + newParent = updateGeneration(parent, t.generation, t.deletedKeys) + } + + if newParent.Type() == node.LeafType { + leaf := newParent.(*node.Leaf) + // if prefix is not found, it's also all deleted. + // TODO check this is the same behaviour as in substrate + const allDeleted = true + if bytes.HasPrefix(leaf.Key, prefix) { + *limit-- + return nil, true, allDeleted } + // not modified so return the leaf of the original + // trie generation. The copied leaf newParent will be + // garbage collected. + return parent, false, allDeleted + } - if len(prefix) == len(c.Key)+1 && length == len(prefix)-1 { - i := prefix[len(c.Key)] + branch := newParent.(*node.Branch) + newParent, updated, allDeleted = t.clearPrefixLimitBranch(branch, prefix, limit) + if !updated { + // not modified so return the node of the original + // trie generation. The copied newParent will be + // garbage collected. 
+ newParent = parent + } - if c.Children[i] == nil { - // child is already nil at the child index - return c, false, true - } + return newParent, updated, allDeleted +} - c.Children[i] = t.deleteNodes(c.Children[i], []byte{}, limit) +func (t *Trie) clearPrefixLimitBranch(branch *node.Branch, prefix []byte, limit *uint32) ( + newParent Node, updated, allDeleted bool) { + newParent = branch + + if bytes.HasPrefix(branch.Key, prefix) { + updated = true // at least one node will be deleted + nilPrefix := ([]byte)(nil) + // TODO return deleted count to replace updated boolean and update limit + newParent = t.deleteNodesLimit(branch, nilPrefix, limit) + allDeleted = newParent == nil + return newParent, updated, allDeleted + } - c.SetDirty(true) - curr = handleDeletion(c, prefix) + if len(prefix) == len(branch.Key)+1 && + bytes.HasPrefix(branch.Key, prefix[:len(prefix)-1]) { + // Prefix is one the children of the branch + return t.clearPrefixLimitChild(branch, prefix, limit) + } - if c.Children[i] == nil { - return curr, true, true - } - return c, true, false - } + noPrefixForNode := len(prefix) <= len(branch.Key) || + lenCommonPrefix(branch.Key, prefix) < len(branch.Key) + if noPrefixForNode { + updated, allDeleted = false, true + return newParent, updated, allDeleted + } - if len(prefix) <= len(c.Key) || length < len(c.Key) { - // this node doesn't have the prefix, return - return c, false, true - } + childIndex := prefix[len(branch.Key)] + childPrefix := prefix[len(branch.Key)+1:] + child := branch.Children[childIndex] - i := prefix[len(c.Key)] + newParent = branch // mostly just a reminder for the reader + branch.Children[childIndex], updated, allDeleted = t.clearPrefixLimit(child, childPrefix, limit) + if updated { + branch.SetDirty(true) + newParent = handleDeletion(branch, prefix) + } - var wasUpdated, allDeleted bool - c.Children[i], wasUpdated, allDeleted = t.clearPrefixLimit(c.Children[i], prefix[len(c.Key)+1:], limit) - if wasUpdated { - c.SetDirty(true) - curr = handleDeletion(c, prefix) - } + return newParent, newParent.IsDirty(), allDeleted +} - return curr, curr.IsDirty(), allDeleted - case *node.Leaf: - length := lenCommonPrefix(c.Key, prefix) - if length == len(prefix) { - *limit-- - return nil, true, true - } - // Prefix not found might be all deleted - return curr, false, true +func (t *Trie) clearPrefixLimitChild(branch *node.Branch, prefix []byte, limit *uint32) ( + newParent Node, updated, allDeleted bool) { + newParent = branch - case nil: - return nil, false, true + childIndex := prefix[len(branch.Key)] + child := branch.Children[childIndex] + + if child == nil { + // TODO ensure this is the same behaviour as in substrate + updated, allDeleted = false, true + return newParent, updated, allDeleted } - return nil, false, true + nilPrefix := ([]byte)(nil) + branch.Children[childIndex] = t.deleteNodesLimit(child, nilPrefix, limit) + branch.SetDirty(true) + + newParent = handleDeletion(branch, prefix) + + updated = true + allDeleted = branch.Children[childIndex] == nil + return newParent, updated, allDeleted } -func (t *Trie) deleteNodes(cn Node, prefix []byte, limit *uint32) (newNode Node) { +func (t *Trie) deleteNodesLimit(parent Node, prefix []byte, limit *uint32) (newParent Node) { if *limit == 0 { - return cn + return parent } - curr := t.maybeUpdateGeneration(cn) + if parent == nil { + return nil + } - switch c := curr.(type) { - case *node.Leaf: + newParent = parent + if parent.GetGeneration() < t.generation { + newParent = updateGeneration(parent, t.generation, 
t.deletedKeys) + } + + if newParent.Type() == node.LeafType { *limit-- return nil - case *node.Branch: - if len(c.Key) != 0 { - prefix = append(prefix, c.Key...) - } + } - for i, child := range c.Children { - if child == nil { - continue - } + branch := newParent.(*node.Branch) - c.Children[i] = t.deleteNodes(child, prefix, limit) + fullKey := concatenateSlices(prefix, branch.Key) - c.SetDirty(true) - curr = handleDeletion(c, prefix) - isAllNil := c.NumChildren() == 0 - if isAllNil && c.Value == nil { - curr = nil - } + nilChildren := node.ChildrenCapacity - branch.NumChildren() - if *limit == 0 { - return curr - } + for i, child := range branch.Children { + if child == nil { + continue } - // Delete the current node as well - if c.Value != nil { - *limit-- + branch.Children[i] = t.deleteNodesLimit(child, fullKey, limit) + if branch.Children[i] == nil { + nilChildren++ } - return nil + + branch.SetDirty(true) + newParent = handleDeletion(branch, fullKey) + if nilChildren == node.ChildrenCapacity && + branch.Value == nil { + return nil + } + + if *limit == 0 { + return newParent + } + } + + if branch.Value != nil { + *limit-- } - return curr + return nil } -// ClearPrefix deletes all key-value pairs from the trie where the key starts with the given prefix -func (t *Trie) ClearPrefix(prefix []byte) { - if len(prefix) == 0 { +// ClearPrefix deletes all nodes in the trie for which the key contains the +// prefix given in little Endian format. +func (t *Trie) ClearPrefix(prefixLE []byte) { + if len(prefixLE) == 0 { t.root = nil return } - p := codec.KeyLEToNibbles(prefix) - if len(p) > 0 && p[len(p)-1] == 0 { - p = p[:len(p)-1] - } + prefix := codec.KeyLEToNibbles(prefixLE) + prefix = bytes.TrimSuffix(prefix, []byte{0}) - t.root, _ = t.clearPrefix(t.root, p) + t.root, _ = t.clearPrefix(t.root, prefix) } -func (t *Trie) clearPrefix(cn Node, prefix []byte) (Node, bool) { - curr := t.maybeUpdateGeneration(cn) - switch c := curr.(type) { - case *node.Branch: - length := lenCommonPrefix(c.Key, prefix) +func (t *Trie) clearPrefix(parent Node, prefix []byte) ( + newParent Node, updated bool) { + if parent == nil { + return nil, false + } - if length == len(prefix) { - // found prefix at this branch, delete it - return nil, true - } + newParent = parent + if parent.GetGeneration() < t.generation { + newParent = updateGeneration(parent, t.generation, t.deletedKeys) + } - // Store the current node and return it, if the trie is not updated. + if bytes.HasPrefix(newParent.GetKey(), prefix) { + return nil, true + } - if len(prefix) == len(c.Key)+1 && length == len(prefix)-1 { - // found prefix at child index, delete child - i := prefix[len(c.Key)] + if newParent.Type() == node.LeafType { + // not modified so return the leaf of the original + // trie generation. The copied newParent will be + // garbage collected. 
+ return parent, false + } - if c.Children[i] == nil { - // child is already nil at the child index - return c, false - } + branch := newParent.(*node.Branch) - c.Children[i] = nil - c.SetDirty(true) - curr = handleDeletion(c, prefix) - return curr, true - } + if len(prefix) == len(branch.Key)+1 && + bytes.HasPrefix(branch.Key, prefix[:len(prefix)-1]) { + // Prefix is one of the children of the branch + childIndex := prefix[len(branch.Key)] + child := branch.Children[childIndex] - if len(prefix) <= len(c.Key) || length < len(c.Key) { - // this node doesn't have the prefix, return - return c, false + if child == nil { + // child is already nil at the child index + // node is not modified so return the branch of the original + // trie generation. The copied newParent will be + // garbage collected. + return parent, false } - var wasUpdated bool - i := prefix[len(c.Key)] + branch.Children[childIndex] = nil + branch.SetDirty(true) + newParent = handleDeletion(branch, prefix) + return newParent, true + } - c.Children[i], wasUpdated = t.clearPrefix(c.Children[i], prefix[len(c.Key)+1:]) - if wasUpdated { - c.SetDirty(true) - curr = handleDeletion(c, prefix) - } + noPrefixForNode := len(prefix) <= len(branch.Key) || + lenCommonPrefix(branch.Key, prefix) < len(branch.Key) + if noPrefixForNode { + // not modified so return the branch of the original + // trie generation. The copied newParent will be + // garbage collected. + return parent, false + } + + childIndex := prefix[len(branch.Key)] + childPrefix := prefix[len(branch.Key)+1:] + child := branch.Children[childIndex] - return curr, curr.IsDirty() - case *node.Leaf: - length := lenCommonPrefix(c.Key, prefix) - if length == len(prefix) { + branch.Children[childIndex], updated = t.clearPrefix(child, childPrefix) + if !updated { + // branch not modified so return the branch of the original + // trie generation. The copied newParent will be + // garbage collected. + return parent, false + } + + branch.SetDirty(true) + newParent = handleDeletion(branch, prefix) + return newParent, true +} + +// Delete removes the node of the trie with the key +// matching the key given in little Endian format. +// If no node is found at this key, nothing is deleted. +func (t *Trie) Delete(keyLE []byte) { + key := codec.KeyLEToNibbles(keyLE) + t.root, _ = t.delete(t.root, key) +} + +func (t *Trie) delete(parent Node, key []byte) (newParent Node, deleted bool) { + if parent == nil { + return nil, false + } + + newParent = parent + if parent.GetGeneration() < t.generation { + newParent = updateGeneration(parent, t.generation, t.deletedKeys) + } + + if newParent.Type() == node.LeafType { + newParent = deleteLeaf(newParent, key) + if newParent == nil { return nil, true } - return c, false - case nil: - return nil, false + // The leaf was not deleted so return the original + // parent without its generation updated. + // The copied newParent will be garbage collected. + return parent, false + } + + branch := newParent.(*node.Branch) + newParent, deleted = t.deleteBranch(branch, key) + if !deleted { + // Nothing was deleted so return the original + // parent without its generation updated. + // The copied newParent will be garbage collected. + return parent, false } - // This should never happen. - return nil, false + + return newParent, true } -// Delete removes any existing value for key from the trie. 
-func (t *Trie) Delete(key []byte) { - k := codec.KeyLEToNibbles(key) - t.root, _ = t.delete(t.root, k) +func deleteLeaf(parent Node, key []byte) (newParent Node) { + if len(key) == 0 || bytes.Equal(key, parent.GetKey()) { + return nil + } + return parent } -func (t *Trie) delete(parent Node, key []byte) (Node, bool) { - // Store the current node and return it, if the trie is not updated. - switch p := t.maybeUpdateGeneration(parent).(type) { - case *node.Branch: +func (t *Trie) deleteBranch(branch *node.Branch, key []byte) (newParent Node, deleted bool) { + if len(key) == 0 || bytes.Equal(branch.Key, key) { + branch.Value = nil + branch.SetDirty(true) + return handleDeletion(branch, key), true + } - length := lenCommonPrefix(p.Key, key) - if bytes.Equal(p.Key, key) || len(key) == 0 { - // found the value at this node - p.Value = nil - p.SetDirty(true) - return handleDeletion(p, key), true - } + commonPrefixLength := lenCommonPrefix(branch.Key, key) + childIndex := key[commonPrefixLength] + childKey := key[commonPrefixLength+1:] + child := branch.Children[childIndex] - n, del := t.delete(p.Children[key[length]], key[length+1:]) - if !del { - // If nothing was deleted then don't copy the path. - // Return the parent without its generation updated. - return parent, false - } + newChild, deleted := t.delete(child, childKey) + if !deleted { + return branch, false + } - p.Children[key[length]] = n - p.SetDirty(true) - n = handleDeletion(p, key) - return n, true - case *node.Leaf: - if bytes.Equal(key, p.Key) || len(key) == 0 { - // Key exists. Delete it. - return nil, true + branch.Children[childIndex] = newChild + branch.SetDirty(true) + newParent = handleDeletion(branch, key) + return newParent, true +} + +// handleDeletion is called when a value is deleted from a branch to handle +// the eventual mutation of the branch depending on its children. +// If the branch has no value and a single child, it will be combined with this child. +// If the branch has a value and no child, it will be changed into a leaf. +func handleDeletion(branch *node.Branch, key []byte) (newNode Node) { + // TODO try to remove key argument just use branch.Key instead? 
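+	// Count the children and remember the index of the first one,
+	// to decide below whether the branch is kept as is, becomes a leaf,
+	// or is merged with its only child.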
+ childrenCount := 0 + firstChildIndex := -1 + for i, child := range branch.Children { + if child == nil { + continue } - // Key doesn't exist, return parent - // without its generation changed - return parent, false - case nil: - return nil, false + if firstChildIndex == -1 { + firstChildIndex = i + } + childrenCount++ + } + + switch { default: - panic(fmt.Sprintf("%T: invalid node: %v (%v)", p, p, key)) - } -} - -// handleDeletion is called when a value is deleted from a branch -// if the updated branch only has 1 child, it should be combined with that child -// if the updated branch only has a value, it should be turned into a leaf -func handleDeletion(p *node.Branch, key []byte) Node { - var n Node = p - length := lenCommonPrefix(p.Key, key) - bitmap := p.ChildrenBitmap() - - // if branch has no children, just a value, turn it into a leaf - if bitmap == 0 && p.Value != nil { - n = node.NewLeaf(key[:length], p.Value, true, p.Generation) - } else if p.NumChildren() == 1 && p.Value == nil { - // there is only 1 child and no value, combine the child branch with this branch - // find index of child - var i int - for i = 0; i < 16; i++ { - bitmap = bitmap >> 1 - if bitmap == 0 { - break + return branch + case childrenCount == 0 && branch.Value != nil: + commonPrefixLength := lenCommonPrefix(branch.Key, key) + return &node.Leaf{ + Key: key[:commonPrefixLength], + Value: branch.Value, + Dirty: true, + Generation: branch.Generation, + } + case childrenCount == 1 && branch.Value == nil: + childIndex := firstChildIndex + child := branch.Children[firstChildIndex] + + if child.Type() == node.LeafType { + childLeafKey := child.GetKey() + newLeafKey := concatenateSlices(branch.Key, intToByteSlice(childIndex), childLeafKey) + return &node.Leaf{ + Key: newLeafKey, + Value: child.GetValue(), + Dirty: true, + Generation: branch.Generation, } } - child := p.Children[i] - switch c := child.(type) { - case *node.Leaf: - key = append(append(p.Key, []byte{byte(i)}...), c.Key...) - const dirty = true - n = node.NewLeaf( - key, - c.Value, - dirty, - p.Generation, - ) - case *node.Branch: - br := new(node.Branch) - br.Key = append(p.Key, append([]byte{byte(i)}, c.Key...)...) - - // adopt the grandchildren - for i, grandchild := range c.Children { - if grandchild != nil { - br.Children[i] = grandchild - // No need to copy and update the generation - // of the grand children since they are not modified. - } - } + childBranch := child.(*node.Branch) + newBranchKey := concatenateSlices(branch.Key, intToByteSlice(childIndex), childBranch.Key) + newBranch := &node.Branch{ + Key: newBranchKey, + Value: childBranch.Value, + Generation: branch.Generation, + Dirty: true, + } - br.Value = c.Value - br.Generation = p.Generation - n = br - default: - // do nothing + // Adopt the grand-children + for i, grandChild := range childBranch.Children { + if grandChild != nil { + newBranch.Children[i] = grandChild + // No need to copy and update the generation + // of the grand children since they are not modified. + } } - n.SetDirty(true) + return newBranch } - return n } -// lenCommonPrefix returns the length of the common prefix between two keys -func lenCommonPrefix(a, b []byte) int { - var length, min = 0, len(a) - - if len(a) > len(b) { +// lenCommonPrefix returns the length of the +// common prefix between two byte slices. 
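+// For example, lenCommonPrefix([]byte{1, 2, 3}, []byte{1, 2, 4, 5}) returns 2.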
+func lenCommonPrefix(a, b []byte) (length int) { + min := len(a) + if len(b) < min { min = len(b) } - for ; length < min; length++ { + for length = 0; length < min; length++ { if a[length] != b[length] { break } @@ -843,3 +1038,33 @@ func lenCommonPrefix(a, b []byte) int { return length } + +func concatenateSlices(sliceOne, sliceTwo []byte, otherSlices ...[]byte) (concatenated []byte) { + allNil := sliceOne == nil && sliceTwo == nil + totalLength := len(sliceOne) + len(sliceTwo) + + for _, otherSlice := range otherSlices { + allNil = allNil && otherSlice == nil + totalLength += len(otherSlice) + } + + if allNil { + // Return a nil slice instead of an an empty slice + // if all slices are nil. + return nil + } + + concatenated = make([]byte, 0, totalLength) + + concatenated = append(concatenated, sliceOne...) + concatenated = append(concatenated, sliceTwo...) + for _, otherSlice := range otherSlices { + concatenated = append(concatenated, otherSlice...) + } + + return concatenated +} + +func intToByteSlice(n int) (slice []byte) { + return []byte{byte(n)} +} diff --git a/lib/trie/trie_test.go b/lib/trie/trie_test.go index ae65eba37b..bf68735575 100644 --- a/lib/trie/trie_test.go +++ b/lib/trie/trie_test.go @@ -91,37 +91,18 @@ func Test_Trie_Snapshot(t *testing.T) { assert.Equal(t, expectedTrie, newTrie) } -func Test_Trie_maybeUpdateGeneration(t *testing.T) { +func Test_Trie_updateGeneration(t *testing.T) { t.Parallel() testCases := map[string]struct { - trie *Trie - node Node - newNode Node - copied bool - expectedTrie *Trie + trieGeneration uint64 + node Node + newNode Node + copied bool + expectedDeletedHashes map[common.Hash]struct{} }{ - "nil node": {}, - "same generation": { - trie: &Trie{ - generation: 1, - }, - node: &node.Leaf{ - Generation: 1, - Key: []byte{1}, - }, - newNode: &node.Leaf{ - Generation: 1, - Key: []byte{1}, - }, - expectedTrie: &Trie{ - generation: 1, - }, - }, "trie generation higher and empty hash": { - trie: &Trie{ - generation: 2, - }, + trieGeneration: 2, node: &node.Leaf{ Generation: 1, Key: []byte{1}, @@ -130,16 +111,11 @@ func Test_Trie_maybeUpdateGeneration(t *testing.T) { Generation: 2, Key: []byte{1}, }, - copied: true, - expectedTrie: &Trie{ - generation: 2, - }, + copied: true, + expectedDeletedHashes: map[common.Hash]struct{}{}, }, "trie generation higher and hash": { - trie: &Trie{ - generation: 2, - deletedKeys: map[common.Hash]struct{}{}, - }, + trieGeneration: 2, node: &node.Leaf{ Generation: 1, Key: []byte{1}, @@ -151,16 +127,13 @@ func Test_Trie_maybeUpdateGeneration(t *testing.T) { HashDigest: []byte{1, 2, 3}, }, copied: true, - expectedTrie: &Trie{ - generation: 2, - deletedKeys: map[common.Hash]struct{}{ - { - 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 1, 2, 3, - }: {}, - }, + expectedDeletedHashes: map[common.Hash]struct{}{ + { + 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 1, 2, 3, + }: {}, }, }, } @@ -170,12 +143,12 @@ func Test_Trie_maybeUpdateGeneration(t *testing.T) { t.Run(name, func(t *testing.T) { t.Parallel() - trie := testCase.trie + deletedHashes := make(map[common.Hash]struct{}) - newNode := trie.maybeUpdateGeneration(testCase.node) + newNode := updateGeneration(testCase.node, testCase.trieGeneration, deletedHashes) assert.Equal(t, testCase.newNode, newNode) - assert.Equal(t, testCase.expectedTrie, trie) + assert.Equal(t, testCase.expectedDeletedHashes, deletedHashes) // Check for deep copy if newNode != nil && testCase.copied { @@ 
-184,6 +157,19 @@ func Test_Trie_maybeUpdateGeneration(t *testing.T) { } }) } + + t.Run("panic on same generation", func(t *testing.T) { + t.Parallel() + node := &node.Leaf{Generation: 1} + const trieGenration = 1 + assert.PanicsWithValue(t, + "current node has the same generation 1 as the trie generation, "+ + "make sure the caller properly checks for the node generation to "+ + "be smaller than the trie generation.", + func() { + updateGeneration(node, trieGenration, nil) + }) + }) } func Test_Trie_RootNode(t *testing.T) { @@ -863,7 +849,7 @@ func Test_nextKey(t *testing.T) { originalTrie := testCase.trie.DeepCopy() - nextKey := nextKey(testCase.trie.root, nil, testCase.key) + nextKey := findNextKey(testCase.trie.root, nil, testCase.key) assert.Equal(t, testCase.nextKey, nextKey) assert.Equal(t, *originalTrie, testCase.trie) // ensure no mutation @@ -928,7 +914,7 @@ func Test_Trie_Put(t *testing.T) { } } -func Test_Trie_tryPut(t *testing.T) { +func Test_Trie_put(t *testing.T) { t.Parallel() testCases := map[string]struct { @@ -1020,7 +1006,7 @@ func Test_Trie_tryPut(t *testing.T) { t.Parallel() trie := testCase.trie - trie.tryPut(testCase.key, testCase.value) + trie.put(testCase.key, testCase.value) assert.Equal(t, testCase.expectedTrie, trie) }) @@ -1074,6 +1060,7 @@ func Test_Trie_insert(t *testing.T) { Key: []byte{}, Value: []byte("leaf"), Generation: 1, + Dirty: true, }, &node.Leaf{Key: []byte{2}}, }, @@ -1236,7 +1223,7 @@ func Test_Trie_insert(t *testing.T) { } } -func Test_Trie_updateBranch(t *testing.T) { +func Test_Trie_insertInBranch(t *testing.T) { t.Parallel() testCases := map[string]struct { @@ -1309,6 +1296,7 @@ func Test_Trie_updateBranch(t *testing.T) { &node.Leaf{ Key: []byte{4, 5}, Value: []byte{6}, + Dirty: true, }, }, }, @@ -1346,6 +1334,7 @@ func Test_Trie_updateBranch(t *testing.T) { &node.Leaf{ Key: []byte{6}, Value: []byte{6}, + Dirty: true, }, }, }, @@ -1372,6 +1361,7 @@ func Test_Trie_updateBranch(t *testing.T) { &node.Branch{ Key: []byte{}, Value: []byte{5}, + Dirty: true, Children: [16]node.Node{ &node.Leaf{Key: []byte{1}}, }, @@ -1403,6 +1393,7 @@ func Test_Trie_updateBranch(t *testing.T) { &node.Branch{ Key: []byte{3}, Value: []byte{5}, + Dirty: true, Children: [16]node.Node{ &node.Leaf{Key: []byte{1}}, }, @@ -1435,6 +1426,7 @@ func Test_Trie_updateBranch(t *testing.T) { &node.Branch{ Key: []byte{}, Value: []byte{5}, + Dirty: true, Children: [16]node.Node{ &node.Leaf{Key: []byte{1}}, }, @@ -1451,7 +1443,7 @@ func Test_Trie_updateBranch(t *testing.T) { trie := new(Trie) - newNode := trie.updateBranch(testCase.parent, testCase.key, testCase.value) + newNode := trie.insertInBranch(testCase.parent, testCase.key, testCase.value) assert.Equal(t, testCase.newNode, newNode) assert.Equal(t, new(Trie), trie) // check no mutation @@ -1478,14 +1470,14 @@ func Test_Trie_LoadFromMap(t *testing.T) { "0xa": "0x01", }, errWrapped: hex.ErrLength, - errMessage: "encoding/hex: odd length hex string: 0xa", + errMessage: "cannot convert key hex to bytes: encoding/hex: odd length hex string: 0xa", }, "bad value": { data: map[string]string{ "0x01": "0xa", }, errWrapped: hex.ErrLength, - errMessage: "encoding/hex: odd length hex string: 0xa", + errMessage: "cannot convert value hex to bytes: encoding/hex: odd length hex string: 0xa", }, "load into empty trie": { data: map[string]string{ @@ -2088,8 +2080,7 @@ func Test_Trie_clearPrefixLimit(t *testing.T) { limit: 1, expectedLimit: 1, newParent: &node.Leaf{ - Key: []byte{1, 2}, - Generation: 1, + Key: []byte{1, 2}, }, allDeleted: true, 
}, @@ -2104,8 +2095,7 @@ func Test_Trie_clearPrefixLimit(t *testing.T) { limit: 1, expectedLimit: 1, newParent: &node.Leaf{ - Key: []byte{1}, - Generation: 1, + Key: []byte{1}, }, allDeleted: true, }, @@ -2152,8 +2142,7 @@ func Test_Trie_clearPrefixLimit(t *testing.T) { limit: 1, expectedLimit: 1, newParent: &node.Branch{ - Key: []byte{1, 2}, - Generation: 1, + Key: []byte{1, 2}, Children: [16]node.Node{ &node.Leaf{Key: []byte{1}}, &node.Leaf{Key: []byte{2}}, @@ -2176,8 +2165,7 @@ func Test_Trie_clearPrefixLimit(t *testing.T) { limit: 1, expectedLimit: 1, newParent: &node.Branch{ - Key: []byte{1}, - Generation: 1, + Key: []byte{1}, Children: [16]node.Node{ &node.Leaf{Key: []byte{1}}, &node.Leaf{Key: []byte{2}}, @@ -2200,8 +2188,7 @@ func Test_Trie_clearPrefixLimit(t *testing.T) { limit: 1, expectedLimit: 1, newParent: &node.Branch{ - Key: []byte{1}, - Generation: 1, + Key: []byte{1}, Children: [16]node.Node{ &node.Leaf{Key: []byte{1}}, &node.Leaf{Key: []byte{2}}, @@ -2250,9 +2237,8 @@ func Test_Trie_clearPrefixLimit(t *testing.T) { limit: 1, expectedLimit: 1, newParent: &node.Branch{ - Key: []byte{1, 2}, - Value: []byte{1}, - Generation: 1, + Key: []byte{1, 2}, + Value: []byte{1}, Children: [16]node.Node{ &node.Leaf{Key: []byte{1}}, }, @@ -2274,9 +2260,8 @@ func Test_Trie_clearPrefixLimit(t *testing.T) { limit: 1, expectedLimit: 1, newParent: &node.Branch{ - Key: []byte{1}, - Value: []byte{1}, - Generation: 1, + Key: []byte{1}, + Value: []byte{1}, Children: [16]node.Node{ &node.Leaf{Key: []byte{1}}, }, @@ -2298,9 +2283,8 @@ func Test_Trie_clearPrefixLimit(t *testing.T) { limit: 1, expectedLimit: 1, newParent: &node.Branch{ - Key: []byte{1}, - Value: []byte{1}, - Generation: 1, + Key: []byte{1}, + Value: []byte{1}, Children: [16]node.Node{ &node.Leaf{Key: []byte{1}}, }, @@ -2623,7 +2607,7 @@ func Test_Trie_deleteNodes(t *testing.T) { trie := testCase.trie expectedTrie := *trie.DeepCopy() - newNode := trie.deleteNodes(testCase.parent, testCase.prefix, &testCase.limit) + newNode := trie.deleteNodesLimit(testCase.parent, testCase.prefix, &testCase.limit) assert.Equal(t, testCase.limit, testCase.limit) assert.Equal(t, testCase.newNode, newNode) @@ -2738,8 +2722,7 @@ func Test_Trie_clearPrefix(t *testing.T) { }, prefix: []byte{1, 3}, newParent: &node.Leaf{ - Key: []byte{1, 2}, - Generation: 1, + Key: []byte{1, 2}, }, }, "leaf parent with key smaller than prefix": { @@ -2751,8 +2734,7 @@ func Test_Trie_clearPrefix(t *testing.T) { }, prefix: []byte{1, 2}, newParent: &node.Leaf{ - Key: []byte{1}, - Generation: 1, + Key: []byte{1}, }, }, "branch parent with common prefix": { @@ -2790,9 +2772,8 @@ func Test_Trie_clearPrefix(t *testing.T) { }, prefix: []byte{1, 3}, newParent: &node.Branch{ - Key: []byte{1, 2}, - Value: []byte{1}, - Generation: 1, + Key: []byte{1, 2}, + Value: []byte{1}, Children: [16]node.Node{ &node.Leaf{}, }, @@ -2811,9 +2792,8 @@ func Test_Trie_clearPrefix(t *testing.T) { }, prefix: []byte{1, 2, 3}, newParent: &node.Branch{ - Key: []byte{1}, - Value: []byte{1}, - Generation: 1, + Key: []byte{1}, + Value: []byte{1}, Children: [16]node.Node{ &node.Leaf{}, }, @@ -2832,9 +2812,8 @@ func Test_Trie_clearPrefix(t *testing.T) { }, prefix: []byte{1, 2}, newParent: &node.Branch{ - Key: []byte{1}, - Value: []byte{1}, - Generation: 1, + Key: []byte{1}, + Value: []byte{1}, Children: [16]node.Node{ &node.Leaf{}, }, @@ -3393,3 +3372,96 @@ func Test_lenCommonPrefix(t *testing.T) { }) } } + +func Test_concatenateSlices(t *testing.T) { + t.Parallel() + + testCases := map[string]struct { + sliceOne 
[]byte + sliceTwo []byte + otherSlices [][]byte + concatenated []byte + }{ + "two nil slices": {}, + "four nil slices": { + otherSlices: [][]byte{nil, nil}, + }, + "only fourth slice not nil": { + otherSlices: [][]byte{ + nil, + {1}, + }, + concatenated: []byte{1}, + }, + "two empty slices": { + sliceOne: []byte{}, + sliceTwo: []byte{}, + concatenated: []byte{}, + }, + "three empty slices": { + sliceOne: []byte{}, + sliceTwo: []byte{}, + otherSlices: [][]byte{{}}, + concatenated: []byte{}, + }, + "concatenate two first slices": { + sliceOne: []byte{1, 2}, + sliceTwo: []byte{3, 4}, + concatenated: []byte{1, 2, 3, 4}, + }, + + "concatenate four slices": { + sliceOne: []byte{1, 2}, + sliceTwo: []byte{3, 4}, + otherSlices: [][]byte{ + {5, 6}, + {7, 8}, + }, + concatenated: []byte{1, 2, 3, 4, 5, 6, 7, 8}, + }, + } + + for name, testCase := range testCases { + testCase := testCase + t.Run(name, func(t *testing.T) { + t.Parallel() + + concatenated := concatenateSlices(testCase.sliceOne, + testCase.sliceTwo, testCase.otherSlices...) + + assert.Equal(t, testCase.concatenated, concatenated) + }) + } +} + +func Benchmark_concatSlices(b *testing.B) { + const sliceSize = 100000 // 100KB + slice1 := make([]byte, sliceSize) + slice2 := make([]byte, sliceSize) + + // 16993 ns/op 245760 B/op 1 allocs/op + b.Run("direct append", func(b *testing.B) { + for i := 0; i < b.N; i++ { + concatenated := append(slice1, slice2...) + concatenated[0] = 1 + } + }) + + // 16340 ns/op 204800 B/op 1 allocs/op + b.Run("append with pre-allocation", func(b *testing.B) { + for i := 0; i < b.N; i++ { + concatenated := make([]byte, 0, len(slice1)+len(slice2)) + concatenated = append(concatenated, slice1...) + concatenated = append(concatenated, slice2...) + concatenated[0] = 1 + } + }) + + // 16453 ns/op 204800 B/op 1 allocs/op + b.Run("concatenation helper function", func(b *testing.B) { + for i := 0; i < b.N; i++ { + concatenated := concatenateSlices(slice1, slice2) + concatenated[0] = 1 + } + }) +}
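
Below is a minimal usage sketch, not part of the diff, tying together the renamed and refactored trie operations above. It assumes an in-package file (it reuses new(Trie) the same way the tests do), and the key literals and expected results in the comments are illustrative only.

package trie

import "fmt"

// exampleUsage walks through the public API touched by this patch:
// keys and prefixes are passed in little Endian format and converted
// to nibbles internally.
func exampleUsage() {
	tr := new(Trie)

	// Three keys sharing the first nibble 0x1.
	tr.Put([]byte{0x11}, []byte("a"))
	tr.Put([]byte{0x12}, []byte("b"))
	tr.Put([]byte{0x13}, []byte("c"))

	fmt.Println(string(tr.Get([]byte{0x12}))) // b

	// The trailing zero nibble of the prefix 0x10 is trimmed, so this
	// matches every key starting with the nibble 0x1.
	fmt.Println(len(tr.GetKeysWithPrefix([]byte{0x10}))) // 3

	// Delete at most two of the three prefixed keys.
	deleted, allDeleted := tr.ClearPrefixLimit([]byte{0x10}, 2)
	fmt.Println(deleted, allDeleted) // 2 false

	// Remove whatever is left under the prefix.
	tr.ClearPrefix([]byte{0x10})
	fmt.Println(tr.Get([]byte{0x11})) // [] (key deleted)
}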