diff --git a/btree.go b/btree.go index 6088670..ea685a5 100644 --- a/btree.go +++ b/btree.go @@ -52,6 +52,7 @@ import ( "io" "sort" "strings" + "sync" ) // Item represents a single object in the tree. @@ -76,8 +77,9 @@ var ( // FreeList represents a free list of btree nodes. By default each // BTree has its own FreeList, but multiple BTrees can share the same // FreeList. -// Two Btrees using the same freelist are not safe for concurrent write access. +// Two Btrees using the same freelist are safe for concurrent write access. type FreeList struct { + mu sync.Mutex freelist []*node } @@ -88,20 +90,25 @@ func NewFreeList(size int) *FreeList { } func (f *FreeList) newNode() (n *node) { + f.mu.Lock() index := len(f.freelist) - 1 if index < 0 { + f.mu.Unlock() return new(node) } n = f.freelist[index] f.freelist[index] = nil f.freelist = f.freelist[:index] + f.mu.Unlock() return } func (f *FreeList) freeNode(n *node) { + f.mu.Lock() if len(f.freelist) < cap(f.freelist) { f.freelist = append(f.freelist, n) } + f.mu.Unlock() } // ItemIterator allows callers of Ascend* to iterate in-order over portions of @@ -123,8 +130,8 @@ func NewWithFreeList(degree int, f *FreeList) *BTree { panic("bad degree") } return &BTree{ - degree: degree, - freelist: f, + degree: degree, + cow: &copyOnWriteContext{freelist: f}, } } @@ -233,7 +240,34 @@ func (s *children) truncate(index int) { type node struct { items items children children - t *BTree + cow *copyOnWriteContext +} + +func (n *node) mutableFor(cow *copyOnWriteContext) *node { + if n.cow == cow { + return n + } + out := cow.newNode() + if cap(out.items) >= len(n.items) { + out.items = out.items[:len(n.items)] + } else { + out.items = make(items, len(n.items), cap(n.items)) + } + copy(out.items, n.items) + // Copy children + if cap(out.children) >= len(n.children) { + out.children = out.children[:len(n.children)] + } else { + out.children = make(children, len(n.children), cap(n.children)) + } + copy(out.children, n.children) + return 
out +} + +func (n *node) mutableChild(i int) *node { + c := n.children[i].mutableFor(n.cow) + n.children[i] = c + return c } // split splits the given node at the given index. The current node shrinks, @@ -241,7 +275,7 @@ type node struct { // containing all items/children after it. func (n *node) split(i int) (Item, *node) { item := n.items[i] - next := n.t.newNode() + next := n.cow.newNode() next.items = append(next.items, n.items[i+1:]...) n.items.truncate(i) if len(n.children) > 0 { @@ -257,7 +291,7 @@ func (n *node) maybeSplitChild(i, maxItems int) bool { if len(n.children[i].items) < maxItems { return false } - first := n.children[i] + first := n.mutableChild(i) item, second := first.split(maxItems / 2) n.items.insertAt(i, item) n.children.insertAt(i+1, second) @@ -291,7 +325,7 @@ func (n *node) insert(item Item, maxItems int) Item { return out } } - return n.children[i].insert(item, maxItems) + return n.mutableChild(i).insert(item, maxItems) } // get finds the given key in the subtree and returns it. @@ -369,10 +403,10 @@ func (n *node) remove(item Item, minItems int, typ toRemove) Item { panic("invalid type") } // If we get to here, we have children. - child := n.children[i] - if len(child.items) <= minItems { + if len(n.children[i].items) <= minItems { return n.growChildAndRemove(i, item, minItems, typ) } + child := n.mutableChild(i) // Either we had enough items to begin with, or we've done some // merging/stealing, because we've got enough now and we're ready to return // stuff. @@ -411,10 +445,10 @@ func (n *node) remove(item Item, minItems int, typ toRemove) Item { // whether we're in case 1 or 2), we'll have enough items and can guarantee // that we hit case A. 
func (n *node) growChildAndRemove(i int, item Item, minItems int, typ toRemove) Item { - child := n.children[i] if i > 0 && len(n.children[i-1].items) > minItems { // Steal from left child - stealFrom := n.children[i-1] + child := n.mutableChild(i) + stealFrom := n.mutableChild(i - 1) stolenItem := stealFrom.items.pop() child.items.insertAt(0, n.items[i-1]) n.items[i-1] = stolenItem @@ -423,7 +457,8 @@ func (n *node) growChildAndRemove(i int, item Item, minItems int, typ toRemove) } } else if i < len(n.items) && len(n.children[i+1].items) > minItems { // steal from right child - stealFrom := n.children[i+1] + child := n.mutableChild(i) + stealFrom := n.mutableChild(i + 1) stolenItem := stealFrom.items.removeAt(0) child.items = append(child.items, n.items[i]) n.items[i] = stolenItem @@ -433,15 +468,15 @@ func (n *node) growChildAndRemove(i int, item Item, minItems int, typ toRemove) } else { if i >= len(n.items) { i-- - child = n.children[i] } + child := n.mutableChild(i) // merge with right child mergeItem := n.items.removeAt(i) mergeChild := n.children.removeAt(i + 1) child.items = append(child.items, mergeItem) child.items = append(child.items, mergeChild.items...) child.children = append(child.children, mergeChild.children...) - n.t.freeNode(mergeChild) + n.cow.freeNode(mergeChild) } return n.remove(item, minItems, typ) } @@ -535,12 +570,53 @@ func (n *node) print(w io.Writer, level int) { // Write operations are not safe for concurrent mutation by multiple // goroutines, but Read operations are. type BTree struct { - degree int - length int - root *node + degree int + length int + root *node + cow *copyOnWriteContext +} + +// copyOnWriteContext pointers determine node ownership... a tree with a write +// context equivalent to a node's write context is allowed to modify that node. +// A tree whose write context does not match a node's is not allowed to modify +// it, and must create a new, writable copy (IE: it's a Clone). 
+// +// When doing any write operation, we maintain the invariant that the current +// node's context is equal to the context of the tree that requested the write. +// We do this by, before we descend into any node, creating a copy with the +// correct context if the contexts don't match. +// +// Since the node we're currently visiting on any write has the requesting +// tree's context, that node is modifiable in place. Children of that node may +// not share context, but before we descend into them, we'll make a mutable +// copy. +type copyOnWriteContext struct { freelist *FreeList } +// Clone clones the btree, lazily. b2 can be used concurrently with +// the original tree, including concurrent writes to b and b2. +// +// The internal tree structure of b is marked read-only and shared between b and +// b2. Writes to both b and b2 use copy-on-write logic, creating new nodes +// whenever one of b's original nodes would have been modified. Read operations +// should have no performance degradation. Write operations for both b and b2 +// will initially experience minor slow-downs caused by additional allocs and +// copies due to the aforementioned copy-on-write logic, but should converge to +// the original performance characteristics of the original tree. +func (b *BTree) Clone() (b2 *BTree) { + // Create two entirely new copy-on-write contexts. + // This operation effectively creates three trees: + // the original, shared nodes (old b.cow) + // the new b.cow nodes + // the new out.cow nodes + cow1, cow2 := *b.cow, *b.cow + out := *b + b.cow = &cow1 + out.cow = &cow2 + return &out +} + // maxItems returns the max number of items to allow per node. 
func (t *BTree) maxItems() int { return t.degree*2 - 1 @@ -552,18 +628,20 @@ func (t *BTree) minItems() int { return t.degree - 1 } -func (t *BTree) newNode() (n *node) { - n = t.freelist.newNode() - n.t = t +func (c *copyOnWriteContext) newNode() (n *node) { + n = c.freelist.newNode() + n.cow = c return } -func (t *BTree) freeNode(n *node) { - // clear to allow GC - n.items.truncate(0) - n.children.truncate(0) - n.t = nil // clear to allow GC - t.freelist.freeNode(n) +func (c *copyOnWriteContext) freeNode(n *node) { + if n.cow == c { + // clear to allow GC + n.items.truncate(0) + n.children.truncate(0) + n.cow = nil + c.freelist.freeNode(n) + } } // ReplaceOrInsert adds the given item to the tree. If an item in the tree @@ -576,16 +654,19 @@ func (t *BTree) ReplaceOrInsert(item Item) Item { panic("nil item being added to BTree") } if t.root == nil { - t.root = t.newNode() + t.root = t.cow.newNode() t.root.items = append(t.root.items, item) t.length++ return nil - } else if len(t.root.items) >= t.maxItems() { - item2, second := t.root.split(t.maxItems() / 2) - oldroot := t.root - t.root = t.newNode() - t.root.items = append(t.root.items, item2) - t.root.children = append(t.root.children, oldroot, second) + } else { + t.root = t.root.mutableFor(t.cow) + if len(t.root.items) >= t.maxItems() { + item2, second := t.root.split(t.maxItems() / 2) + oldroot := t.root + t.root = t.cow.newNode() + t.root.items = append(t.root.items, item2) + t.root.children = append(t.root.children, oldroot, second) + } } out := t.root.insert(item, t.maxItems()) if out == nil { @@ -616,11 +697,12 @@ func (t *BTree) deleteItem(item Item, typ toRemove) Item { if t.root == nil || len(t.root.items) == 0 { return nil } + t.root = t.root.mutableFor(t.cow) out := t.root.remove(item, t.minItems(), typ) if len(t.root.items) == 0 && len(t.root.children) > 0 { oldroot := t.root t.root = t.root.children[0] - t.freeNode(oldroot) + t.cow.freeNode(oldroot) } if out != nil { t.length-- diff --git 
a/btree_test.go b/btree_test.go index f01394a..5da9d8b 100644 --- a/btree_test.go +++ b/btree_test.go @@ -20,6 +20,7 @@ import ( "math/rand" "reflect" "sort" + "sync" "testing" "time" ) @@ -360,6 +361,50 @@ func BenchmarkInsert(b *testing.B) { } } +func BenchmarkDeleteInsert(b *testing.B) { + b.StopTimer() + insertP := perm(benchmarkTreeSize) + tr := New(*btreeDegree) + for _, item := range insertP { + tr.ReplaceOrInsert(item) + } + b.StartTimer() + for i := 0; i < b.N; i++ { + tr.Delete(insertP[i%benchmarkTreeSize]) + tr.ReplaceOrInsert(insertP[i%benchmarkTreeSize]) + } +} + +func BenchmarkDeleteInsertCloneOnce(b *testing.B) { + b.StopTimer() + insertP := perm(benchmarkTreeSize) + tr := New(*btreeDegree) + for _, item := range insertP { + tr.ReplaceOrInsert(item) + } + tr = tr.Clone() + b.StartTimer() + for i := 0; i < b.N; i++ { + tr.Delete(insertP[i%benchmarkTreeSize]) + tr.ReplaceOrInsert(insertP[i%benchmarkTreeSize]) + } +} + +func BenchmarkDeleteInsertCloneEachTime(b *testing.B) { + b.StopTimer() + insertP := perm(benchmarkTreeSize) + tr := New(*btreeDegree) + for _, item := range insertP { + tr.ReplaceOrInsert(item) + } + b.StartTimer() + for i := 0; i < b.N; i++ { + tr = tr.Clone() + tr.Delete(insertP[i%benchmarkTreeSize]) + tr.ReplaceOrInsert(insertP[i%benchmarkTreeSize]) + } +} + func BenchmarkDelete(b *testing.B) { b.StopTimer() insertP := perm(benchmarkTreeSize) @@ -409,6 +454,30 @@ func BenchmarkGet(b *testing.B) { } } +func BenchmarkGetCloneEachTime(b *testing.B) { + b.StopTimer() + insertP := perm(benchmarkTreeSize) + removeP := perm(benchmarkTreeSize) + b.StartTimer() + i := 0 + for i < b.N { + b.StopTimer() + tr := New(*btreeDegree) + for _, v := range insertP { + tr.ReplaceOrInsert(v) + } + b.StartTimer() + for _, item := range removeP { + tr = tr.Clone() + tr.Get(item) + i++ + if i >= b.N { + return + } + } + } +} + type byInts []Item func (a byInts) Len() int { @@ -561,3 +630,60 @@ func BenchmarkDescendLessOrEqual(b *testing.B) { } } } + +const 
cloneTestSize = 10000 + +func cloneTest(t *testing.T, b *BTree, start int, p []Item, wg *sync.WaitGroup, trees *[]*BTree) { + t.Logf("Starting new clone at %v", start) + *trees = append(*trees, b) + for i := start; i < cloneTestSize; i++ { + b.ReplaceOrInsert(p[i]) + if i%(cloneTestSize/5) == 0 { + wg.Add(1) + go cloneTest(t, b.Clone(), i+1, p, wg, trees) + } + } + wg.Done() +} + +func TestCloneConcurrentOperations(t *testing.T) { + b := New(*btreeDegree) + trees := []*BTree{} + p := perm(cloneTestSize) + var wg sync.WaitGroup + wg.Add(1) + go cloneTest(t, b, 0, p, &wg, &trees) + wg.Wait() + want := rang(cloneTestSize) + t.Logf("Starting equality checks on %d trees", len(trees)) + for i, tree := range trees { + if !reflect.DeepEqual(want, all(tree)) { + t.Errorf("tree %v mismatch", i) + } + } + t.Log("Removing half from first half") + toRemove := rang(cloneTestSize)[cloneTestSize/2:] + for i := 0; i < len(trees)/2; i++ { + tree := trees[i] + wg.Add(1) + go func() { + for _, item := range toRemove { + tree.Delete(item) + } + wg.Done() + }() + } + wg.Wait() + t.Log("Checking all values again") + for i, tree := range trees { + var wantpart []Item + if i < len(trees)/2 { + wantpart = want[:cloneTestSize/2] + } else { + wantpart = want + } + if got := all(tree); !reflect.DeepEqual(wantpart, got) { + t.Errorf("tree %v mismatch, want %v got %v", i, len(want), len(got)) + } + } +}