Skip to content

Commit

Permalink
Use generics to specify the tree value instead of interface{}
Browse files Browse the repository at this point in the history
  • Loading branch information
superfell committed Dec 25, 2021
1 parent b22fcee commit d01f77d
Show file tree
Hide file tree
Showing 8 changed files with 237 additions and 234 deletions.
71 changes: 36 additions & 35 deletions art.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,17 +10,17 @@ import (
// Tree is an Adaptive Radix Tree. Keys are arbitrary byte slices, and the path through the tree
// is the key. Values are stored on Leaves of the tree. The tree is organized in lexicographical
// order of the keys.
type Tree struct {
root node
type Tree[V any] struct {
root node[V]
}

// Put inserts or updates a value in the tree associated with the provided key. Value can be any
// interface value, including nil. key can be an arbitrary byte slice, including the empty slice.
func (a *Tree) Put(key []byte, value interface{}) {
func (a *Tree[V]) Put(key []byte, value V) {
a.root = a.put(a.root, key, value)
}

func (a *Tree) put(n node, key []byte, value interface{}) node {
func (a *Tree[V]) put(n node[V], key []byte, value V) node[V] {
if n == nil {
return newPathLeaf(key, value)
}
Expand Down Expand Up @@ -51,27 +51,28 @@ func (a *Tree) put(n node, key []byte, value interface{}) node {

// Get the value for the provided key. exists is true if the key contains a value in the tree,
// false otherwise. The exists flag can be useful if you are storing nil values in the tree.
func (a *Tree) Get(key []byte) (value interface{}, exists bool) {
func (a *Tree[V]) Get(key []byte) (value V, exists bool) {
var zero V
if a.root == nil {
return nil, false
return zero, false
}
curr := a.root
for {
h := curr.header()
if !bytes.HasPrefix(key, h.path.asSlice()) {
return nil, false
return zero, false
}
key = key[h.path.len:]
if len(key) == 0 {
leaf := curr.valueNode()
if leaf != nil {
return leaf.value, true
}
return nil, false
return zero, false
}
next := curr.getChildNode(key)
if next == nil {
return nil, false
return zero, false
}
curr = *next
key = key[1:]
Expand All @@ -80,14 +81,14 @@ func (a *Tree) Get(key []byte) (value interface{}, exists bool) {

// Delete removes the value associated with the supplied key if it exists. Its okay to
// call Delete with a key that doesn't exist.
func (a *Tree) Delete(key []byte) {
func (a *Tree[V]) Delete(key []byte) {
if a.root == nil {
return
}
a.root = a.delete(a.root, key)
}

func (a *Tree) delete(n node, key []byte) node {
func (a *Tree[V]) delete(n node[V], key []byte) node[V] {
h := n.header()
if !bytes.HasPrefix(key, h.path.asSlice()) {
return n
Expand Down Expand Up @@ -122,18 +123,18 @@ const (
// key value is only valid for the duration of the callback, and it should not be
// modified. If the callback needs access to the key after the callback returns, it
// must copy the key. The tree should not be modified during a callback.
type ConsumerFn func(key []byte, value interface{}) WalkState
//type ConsumerFn[V any] func[V](key []byte, value V) WalkState

// Walk will call the provided callback function with each key/value pair, in key order.
// The callback return value can be used to continue or stop the walk
func (a *Tree) Walk(callback ConsumerFn) {
func (a *Tree[V]) Walk(callback func(key []byte, value V) WalkState) {
if a.root == nil {
return
}
a.walk(a.root, make([]byte, 0, 32), callback)
}

func (a *Tree) walk(n node, prefix []byte, callback ConsumerFn) WalkState {
func (a *Tree[V]) walk(n node[V], prefix []byte, callback func(key []byte, value V) WalkState) WalkState {
h := n.header()
prefix = append(prefix, h.path.asSlice()...)
if h.hasValue {
Expand All @@ -142,7 +143,7 @@ func (a *Tree) walk(n node, prefix []byte, callback ConsumerFn) WalkState {
return Stop
}
}
return n.iterateChildren(func(k byte, cn node) WalkState {
return n.iterateChildren(func(k byte, cn node[V]) WalkState {
return a.walk(cn, append(prefix, k), callback)
})
}
Expand All @@ -154,7 +155,7 @@ func (a *Tree) walk(n node, prefix []byte, callback ConsumerFn) WalkState {
// WalkRange(cb). WalkRange([]byte{1}, nil, cb) will wall all that are equal to or greater than [1]
// WalkRange([]byte{1}, []byte{2},cb) will walk all keys with a prefix of [1].
// The callback return value can be used to continue or stop the walk
func (a *Tree) WalkRange(start []byte, end []byte, callback ConsumerFn) {
func (a *Tree[V]) WalkRange(start []byte, end []byte, callback func(key []byte, value V) WalkState) {
if a.root == nil {
return
}
Expand All @@ -165,7 +166,7 @@ func (a *Tree) WalkRange(start []byte, end []byte, callback ConsumerFn) {
a.walkStart(a.root, make([]byte, 0, 32), keyLimit{start, 0}, cmpEnd, callback)
}

func (a *Tree) walkStart(n node, current []byte, start, end keyLimit, callback ConsumerFn) WalkState {
func (a *Tree[V]) walkStart(n node[V], current []byte, start, end keyLimit, callback func(key []byte, value V)WalkState) WalkState {
h := n.header()
for _, k := range h.path.asSlice() {
start.cmpSegment(k)
Expand All @@ -181,7 +182,7 @@ func (a *Tree) walkStart(n node, current []byte, start, end keyLimit, callback C
return Stop
}
}
return n.iterateChildrenRange(start.minNextKey(), end.stopKey(), func(k byte, cn node) WalkState {
return n.iterateChildrenRange(start.minNextKey(), end.stopKey(), func(k byte, cn node[V]) WalkState {
nextStart, nextEnd := start, end
nextStart.cmpSegment(k)
nextEnd.cmpSegment(k)
Expand Down Expand Up @@ -236,7 +237,7 @@ func compare(a, b byte) int {

// PrettyPrint will generate a compact representation of the state of the tree. Its primary
// use is in diagnostics, or helping to understand how the tree is constructed.
func (a *Tree) PrettyPrint(w io.Writer) {
func (a *Tree[V]) PrettyPrint(w io.Writer) {
if a.root == nil {
io.WriteString(w, "[empty]\n")
return
Expand All @@ -256,7 +257,7 @@ type Stats struct {
}

// Stats returns current statistics about the nodes & keys in the tree.
func (a *Tree) Stats() *Stats {
func (a *Tree[V]) Stats() *Stats {
s := &Stats{}
if a.root == nil {
return s
Expand All @@ -271,32 +272,32 @@ type writer interface {
io.StringWriter
}

type nodeConsumer func(k byte, n node) WalkState
//type nodeConsumer func[V any](k byte, n node[V]) WalkState

type node interface {
type node[V any] interface {
header() nodeHeader
keyPath() *keyPath

canAddChild() bool
addChildNode(key byte, child node)
getChildNode(key []byte) *node
iterateChildren(cb nodeConsumer) WalkState
addChildNode(key byte, child node[V])
getChildNode(key []byte) *node[V]
iterateChildren(cb func(k byte, n node[V]) WalkState) WalkState
// iterateChildrenRange a potential subset of children where start >= key < end
iterateChildrenRange(start, end int, cb nodeConsumer) WalkState
iterateChildrenRange(start, end int, cb func(k byte, n node[V]) WalkState) WalkState

canSetNodeValue() bool
setNodeValue(n *leaf)
valueNode() *leaf
setNodeValue(n *leaf[V])
valueNode() *leaf[V]

// remove the value (or child) for this node, the node can be removed from the tree
// if it returns nil, or it can return a different node instance and the
// tree will be updated to that one (i.e. so that nodes can shrink to
// a smaller type)
removeValue() node
removeValue() node[V]
removeChild(key byte)

grow() node
shrink() node
grow() node[V]
shrink() node[V]

pretty(indent int, dest writer)
stats(s *Stats)
Expand All @@ -316,13 +317,13 @@ type nodeHeader struct {
// splitNodePath will if needed split the supplied node into 2 based on the
// overlap of the key and the nodes compressed path. If the key and the path are the
// same then there's no need to split and the node is returned unaltered.
func splitNodePath(key []byte, n node) (remainingKey []byte, out node) {
func splitNodePath[V any](key []byte, n node[V]) (remainingKey []byte, out node[V]) {
h := n.header()
path := h.path.asSlice()
prefixLen := prefixSize(key, path)
if prefixLen < len(path) {
// need to split into 2
parent := &node4{}
parent := &node4[V]{}
parent.path.assign(path[:prefixLen])
parent.addChildNode(path[prefixLen], n)
// +1 because we consumed a byte for the child key
Expand All @@ -347,7 +348,7 @@ func writeIndent(indent int, w io.Writer) {
w.Write(spaces[:indent])
}

func writeNode(n node, name string, indent int, w writer) {
func writeNode[V any](n node[V], name string, indent int, w writer) {
w.WriteByte('[')
w.WriteString(name)
h := n.header()
Expand All @@ -359,7 +360,7 @@ func writeNode(n node, name string, indent int, w writer) {
} else {
w.WriteByte('\n')
}
n.iterateChildren(func(k byte, child node) WalkState {
n.iterateChildren(func(k byte, child node[V]) WalkState {
writeIndent(indent+2, w)
fmt.Fprintf(w, "0x%02X: ", k)
child.pretty(indent+8, w)
Expand Down
Loading

0 comments on commit d01f77d

Please sign in to comment.