Skip to content

Commit

Permalink
Add basic merge operator tests
Browse files Browse the repository at this point in the history
Fix various bugs in handling merges during iteration.

Fixes #3
  • Loading branch information
petermattis committed Aug 11, 2018
1 parent 8dd6b3a commit 16bf996
Show file tree
Hide file tree
Showing 5 changed files with 317 additions and 76 deletions.
24 changes: 19 additions & 5 deletions db_iter.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,9 @@ import (
type dbIterPos int8

const (
dbIterCur dbIterPos = iota
dbIterNext
dbIterPrev
dbIterCur dbIterPos = 0
dbIterNext = 1
dbIterPrev = -1
)

type dbIter struct {
Expand Down Expand Up @@ -64,6 +64,7 @@ func (i *dbIter) findNextEntry() bool {
return false
}
}

return false
}

Expand Down Expand Up @@ -100,6 +101,7 @@ func (i *dbIter) findPrevEntry() bool {
return false
}
}

return false
}

Expand All @@ -114,6 +116,7 @@ func (i *dbIter) mergeNext() bool {
for {
i.iter.Next()
if !i.iter.Valid() {
i.pos = dbIterNext
return true
}
key := i.iter.Key()
Expand Down Expand Up @@ -156,6 +159,7 @@ func (i *dbIter) mergePrev() bool {
for {
i.iter.Prev()
if !i.iter.Valid() {
i.pos = dbIterPrev
return true
}
key := i.iter.Key()
Expand Down Expand Up @@ -223,8 +227,13 @@ func (i *dbIter) Next() bool {
if i.err != nil {
return false
}
if i.pos != dbIterNext {
switch i.pos {
case dbIterCur:
i.iter.NextUserKey()
case dbIterPrev:
i.iter.NextUserKey()
i.iter.NextUserKey()
case dbIterNext:
}
return i.findNextEntry()
}
Expand All @@ -233,8 +242,13 @@ func (i *dbIter) Prev() bool {
if i.err != nil {
return false
}
if i.pos != dbIterPrev {
switch i.pos {
case dbIterCur:
i.iter.PrevUserKey()
case dbIterNext:
i.iter.PrevUserKey()
i.iter.PrevUserKey()
case dbIterPrev:
}
return i.findPrevEntry()
}
Expand Down
78 changes: 49 additions & 29 deletions db_iter_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"bytes"
"fmt"
"math/rand"
"strconv"
"strings"
"testing"
"time"
Expand All @@ -16,49 +17,68 @@ func TestDBIter(t *testing.T) {
var keys []db.InternalKey
var vals [][]byte

newIter := func(seqNum uint64) *dbIter {
return &dbIter{
cmp: db.DefaultComparer.Compare,
merge: db.DefaultMerger.Merge,
iter: &fakeIter{keys: keys, vals: vals},
seqNum: seqNum,
}
}

datadriven.RunTest(t, "testdata/db_iter", func(d *datadriven.TestData) string {
switch d.Cmd {
case "define":
keys = nil
vals = nil
keys = keys[:0]
vals = vals[:0]
for _, key := range strings.Split(d.Input, "\n") {
j := strings.Index(key, ":")
keys = append(keys, makeIkey(key[:j]))
vals = append(vals, []byte(key[j+1:]))
}
return ""

case "next":
seek := fakeIkey(strings.TrimSpace(d.Input))
iter := &dbIter{
cmp: db.DefaultComparer.Compare,
iter: &fakeIter{keys: keys, vals: vals},
seqNum: seek.SeqNum(),
}

var b bytes.Buffer
for iter.SeekGE([]byte(seek.UserKey)); iter.Valid(); iter.Next() {
fmt.Fprintf(&b, "%s:%s\n", iter.Key(), iter.Value())
case "iter":
if len(d.CmdArgs) != 1 || len(d.CmdArgs[0].Vals) != 1 || d.CmdArgs[0].Key != "seq" {
return fmt.Sprintf("iter seq=<value>\n")
}
if err := iter.Error(); err != nil {
fmt.Fprintf(&b, "err=%v\n", err)
}
return b.String()

case "prev":
seek := fakeIkey(strings.TrimSpace(d.Input))
iter := &dbIter{
cmp: db.DefaultComparer.Compare,
iter: &fakeIter{keys: keys, vals: vals},
seqNum: seek.SeqNum(),
seqNum, err := strconv.Atoi(d.CmdArgs[0].Vals[0])
if err != nil {
return err.Error()
}

iter := newIter(uint64(seqNum))
var b bytes.Buffer
for iter.SeekLT([]byte(seek.UserKey)); iter.Valid(); iter.Prev() {
fmt.Fprintf(&b, "%s:%s\n", iter.Key(), iter.Value())
}
if err := iter.Error(); err != nil {
fmt.Fprintf(&b, "err=%v\n", err)
for _, line := range strings.Split(d.Input, "\n") {
parts := strings.Fields(line)
if len(parts) == 0 {
continue
}
switch parts[0] {
case "seek-ge":
if len(parts) != 2 {
return fmt.Sprintf("seek-ge <key>\n")
}
iter.SeekGE([]byte(strings.TrimSpace(parts[1])))
case "seek-lt":
if len(parts) != 2 {
return fmt.Sprintf("seek-lt <key>\n")
}
iter.SeekLT([]byte(strings.TrimSpace(parts[1])))
case "next":
iter.Next()
case "prev":
iter.Prev()
default:
return fmt.Sprintf("unknown op: %s", parts[0])
}
if iter.Valid() {
fmt.Fprintf(&b, "%s:%s\n", iter.Key(), iter.Value())
} else if err := iter.Error(); err != nil {
fmt.Fprintf(&b, "err=%v\n", err)
} else {
fmt.Fprintf(&b, ".\n")
}
}
return b.String()
}
Expand Down
12 changes: 12 additions & 0 deletions iterator_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,15 @@ func (f *fakeIter) SeekLT(key []byte) {
break
}
}
if f.Valid() {
key := f.keys[f.index]
for ; f.index > 0; f.index-- {
pkey := f.keys[f.index-1]
if db.DefaultComparer.Compare(pkey.UserKey, key.UserKey) < 0 {
break
}
}
}
}

func (f *fakeIter) First() {
Expand All @@ -91,6 +100,9 @@ func (f *fakeIter) NextUserKey() bool {
if f.index == -1 {
return f.Next()
}
if f.index == len(f.keys) {
return false
}
key := f.keys[f.index]
for f.Next() {
if db.DefaultComparer.Compare(key.UserKey, f.Key().UserKey) < 0 {
Expand Down
Loading

0 comments on commit 16bf996

Please sign in to comment.