diff --git a/x/merkledb/trie_test.go b/x/merkledb/trie_test.go index 713ae7f1e59a..530a2db633fc 100644 --- a/x/merkledb/trie_test.go +++ b/x/merkledb/trie_test.go @@ -1181,6 +1181,113 @@ func TestTrieViewInvalidChildrenExcept(t *testing.T) { require.Empty(view1.childViews) } +func Test_Trie_CommitToParentView_Concurrent(t *testing.T) { + for i := 0; i < 5000; i++ { + dbTrie, err := getBasicDB() + require.NoError(t, err) + require.NotNil(t, dbTrie) + + baseView, err := dbTrie.NewView() + require.NoError(t, err) + + parentView, err := baseView.NewView() + require.NoError(t, err) + err = parentView.Insert(context.Background(), []byte{0}, []byte{0}) + require.NoError(t, err) + + childView1, err := parentView.NewView() + require.NoError(t, err) + err = childView1.Insert(context.Background(), []byte{1}, []byte{1}) + require.NoError(t, err) + + childView2, err := childView1.NewView() + require.NoError(t, err) + err = childView2.Insert(context.Background(), []byte{2}, []byte{2}) + require.NoError(t, err) + + var wg sync.WaitGroup + wg.Add(3) + go func() { + defer wg.Done() + require.NoError(t, parentView.CommitToParent(context.Background())) + }() + go func() { + defer wg.Done() + require.NoError(t, childView1.CommitToParent(context.Background())) + }() + go func() { + defer wg.Done() + require.NoError(t, childView2.CommitToParent(context.Background())) + }() + + wg.Wait() + + val0, err := baseView.GetValue(context.Background(), []byte{0}) + require.NoError(t, err) + require.Equal(t, []byte{0}, val0) + + val1, err := baseView.GetValue(context.Background(), []byte{1}) + require.NoError(t, err) + require.Equal(t, []byte{1}, val1) + + val2, err := baseView.GetValue(context.Background(), []byte{2}) + require.NoError(t, err) + require.Equal(t, []byte{2}, val2) + } +} + +func Test_Trie_CommitToParentDB_Concurrent(t *testing.T) { + for i := 0; i < 5000; i++ { + dbTrie, err := getBasicDB() + require.NoError(t, err) + require.NotNil(t, dbTrie) + + parentView, err := dbTrie.NewView() + require.NoError(t, err) + err = parentView.Insert(context.Background(), []byte{0}, []byte{0}) + require.NoError(t, err) + + childView1, err := parentView.NewView() + require.NoError(t, err) + err = childView1.Insert(context.Background(), []byte{1}, []byte{1}) + require.NoError(t, err) + + childView2, err := childView1.NewView() + require.NoError(t, err) + err = childView2.Insert(context.Background(), []byte{2}, []byte{2}) + require.NoError(t, err) + + var wg sync.WaitGroup + wg.Add(3) + go func() { + defer wg.Done() + require.NoError(t, parentView.CommitToParent(context.Background())) + }() + go func() { + defer wg.Done() + require.NoError(t, childView1.CommitToParent(context.Background())) + }() + go func() { + defer wg.Done() + require.NoError(t, childView2.CommitToParent(context.Background())) + }() + + wg.Wait() + + val0, err := dbTrie.GetValue(context.Background(), []byte{0}) + require.NoError(t, err) + require.Equal(t, []byte{0}, val0) + + val1, err := dbTrie.GetValue(context.Background(), []byte{1}) + require.NoError(t, err) + require.Equal(t, []byte{1}, val1) + + val2, err := dbTrie.GetValue(context.Background(), []byte{2}) + require.NoError(t, err) + require.Equal(t, []byte{2}, val2) + } +} + func Test_Trie_ConcurrentReadWrite(t *testing.T) { require := require.New(t) diff --git a/x/merkledb/trieview.go b/x/merkledb/trieview.go index 4f95711e18c3..a1fc676d2b9a 100644 --- a/x/merkledb/trieview.go +++ b/x/merkledb/trieview.go @@ -27,7 +27,7 @@ const defaultPreallocationSize = 100 var ( ErrCommitted = errors.New("view has been committed") - ErrInvalid = errors.New("the trie this view was based on has changed, rending this view invalid") + ErrInvalid = errors.New("the trie this view was based on has changed, rendering this view invalid") ErrOddLengthWithValue = errors.New( "the underlying db only supports whole number of byte keys, so cannot record changes with odd nibble length", ) @@ -222,10 +222,8 @@ func (t *trieView) calculateNodeIDs(ctx context.Context) error { // ensure that the view under this one is up-to-date before potentially pulling in nodes from it // getting the Merkle root forces any unupdated nodes to recalculate their ids - if t.parentTrie != nil { - if _, err := t.getParentTrie().GetMerkleRoot(ctx); err != nil { - return err - } + if _, err := t.getParentTrie().GetMerkleRoot(ctx); err != nil { + return err } if err := t.applyChangedValuesToTrie(ctx); err != nil { @@ -567,11 +565,10 @@ func (t *trieView) commitChanges(ctx context.Context, trieToCommit *trieView) er // CommitToParent commits the changes from this view to its parent Trie func (t *trieView) CommitToParent(ctx context.Context) error { - // if we are about to write to the db, then we to hold the commitLock - if t.getParentTrie() == t.db { - t.db.commitLock.Lock() - defer t.db.commitLock.Unlock() - } + // TODO: Only lock the commitlock when the parent is the DB + // TODO: fix concurrency bugs with CommitToParent + t.db.commitLock.Lock() + defer t.db.commitLock.Unlock() t.lock.Lock() defer t.lock.Unlock() @@ -597,7 +594,7 @@ func (t *trieView) commitToParent(ctx context.Context) error { return err } - // overwrite this view with changes from the incoming view + // write this view's changes into its parent if err := t.getParentTrie().commitChanges(ctx, t); err != nil { return err }