diff --git a/lib/backend/dynamo/dynamodbbk.go b/lib/backend/dynamo/dynamodbbk.go index ea1dd5535b739..1e8332ef072af 100644 --- a/lib/backend/dynamo/dynamodbbk.go +++ b/lib/backend/dynamo/dynamodbbk.go @@ -363,6 +363,10 @@ func (b *Backend) GetRange(ctx context.Context, startKey []byte, endKey []byte, if len(endKey) == 0 { return nil, trace.BadParameter("missing parameter endKey") } + if limit <= 0 { + limit = backend.DefaultRangeLimit + } + result, err := b.getAllRecords(ctx, startKey, endKey, limit) if err != nil { return nil, trace.Wrap(err) @@ -383,6 +387,7 @@ func (b *Backend) GetRange(ctx context.Context, startKey []byte, endKey []byte, func (b *Backend) getAllRecords(ctx context.Context, startKey []byte, endKey []byte, limit int) (*getResult, error) { var result getResult + // this code is being extra careful here not to introduce endless loop // by some unfortunate series of events for i := 0; i < backend.DefaultRangeLimit/100; i++ { @@ -391,7 +396,9 @@ func (b *Backend) getAllRecords(ctx context.Context, startKey []byte, endKey []b return nil, trace.Wrap(err) } result.records = append(result.records, re.records...) - if len(result.records) >= limit || len(re.lastEvaluatedKey) == 0 { + // If the limit was exceeded or there are no more records to fetch return the current result + // otherwise updated lastEvaluatedKey and proceed with obtaining new records. + if (limit != 0 && len(result.records) >= limit) || len(re.lastEvaluatedKey) == 0 { if len(result.records) == backend.DefaultRangeLimit { b.Warnf("Range query hit backend limit. (this is a bug!) startKey=%q,limit=%d", startKey, backend.DefaultRangeLimit) } @@ -744,12 +751,12 @@ func (b *Backend) getRecords(ctx context.Context, startKey, endKey string, limit // isExpired returns 'true' if the given object (record) has a TTL and // it's due. -func (r *record) isExpired() bool { +func (r *record) isExpired(now time.Time) bool { if r.Expires == nil { return false } expiryDateUTC := time.Unix(*r.Expires, 0).UTC() - return time.Now().UTC().After(expiryDateUTC) + return now.UTC().After(expiryDateUTC) } func removeDuplicates(elements []record) []record { @@ -868,7 +875,7 @@ func (b *Backend) getKey(ctx context.Context, key []byte) (*record, error) { return nil, trace.WrapWithMessage(err, "failed to unmarshal dynamo item %q", string(key)) } // Check if key expired, if expired delete it - if r.isExpired() { + if r.isExpired(b.clock.Now()) { if err := b.deleteKey(ctx, key); err != nil { b.Warnf("Failed deleting expired key %q: %v", key, err) } diff --git a/lib/backend/dynamo/dynamodbbk_test.go b/lib/backend/dynamo/dynamodbbk_test.go index d2b336cd2441f..aba11048f0c64 100644 --- a/lib/backend/dynamo/dynamodbbk_test.go +++ b/lib/backend/dynamo/dynamodbbk_test.go @@ -22,11 +22,12 @@ import ( "testing" "time" + "github.com/gravitational/trace" + "github.com/jonboulle/clockwork" + "github.com/gravitational/teleport/lib/backend" "github.com/gravitational/teleport/lib/backend/test" "github.com/gravitational/teleport/lib/utils" - "github.com/gravitational/trace" - "github.com/jonboulle/clockwork" ) func TestMain(m *testing.M) { @@ -70,7 +71,7 @@ func TestDynamoDB(t *testing.T) { if err != nil { return nil, nil, trace.Wrap(err) } - clock := clockwork.NewFakeClock() + clock := clockwork.NewFakeClockAt(time.Now()) uut.clock = clock return uut, clock, nil } diff --git a/lib/backend/test/suite.go b/lib/backend/test/suite.go index 195236eba1b6a..2ca3bc04462e6 100644 --- a/lib/backend/test/suite.go +++ b/lib/backend/test/suite.go @@ -22,6 +22,7 @@ import ( "context" "encoding/hex" "errors" + "fmt" "math/rand" "sync" "sync/atomic" @@ -140,7 +141,6 @@ func RunBackendComplianceSuite(t *testing.T, newBackend Constructor) { t.Run("Events", func(t *testing.T) { testEvents(t, newBackend) }) - t.Run("WatchersClose", func(t *testing.T) { testWatchersClose(t, newBackend) }) @@ -156,6 +156,14 @@ func RunBackendComplianceSuite(t *testing.T, newBackend Constructor) { t.Run("Mirror", func(t *testing.T) { testMirror(t, newBackend) }) + + t.Run("FetchLimit", func(t *testing.T) { + testFetchLimit(t, newBackend) + }) + + t.Run("Limit", func(t *testing.T) { + testLimit(t, newBackend) + }) } // RequireItems asserts that the supplied `actual` items collection matches @@ -572,6 +580,72 @@ func testEvents(t *testing.T, newBackend Constructor) { requireEvent(t, watcher, types.OpDelete, item.Key, 2*time.Second) } +// testFetchLimit tests fetch max items size limit. +func testFetchLimit(t *testing.T, newBackend Constructor) { + uut, _, err := newBackend() + require.NoError(t, err) + defer func() { require.NoError(t, uut.Close()) }() + + prefix := MakePrefix() + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + // Allocate 65KB buffer. + buff := make([]byte, 1<<16) + itemsCount := 20 + // Fill the backend with events that total size is greater than 1MB (65KB * 20 > 1MB). + for i := 0; i < itemsCount; i++ { + item := &backend.Item{Key: prefix(fmt.Sprintf("/db/database%d", i)), Value: buff} + _, err = uut.Put(ctx, *item) + require.NoError(t, err) + } + + result, err := uut.GetRange(ctx, prefix("/db"), backend.RangeEnd(prefix("/db")), backend.NoLimit) + require.NoError(t, err) + require.Equal(t, itemsCount, len(result.Items)) +} + +// testLimit tests limit. +func testLimit(t *testing.T, newBackend Constructor) { + uut, clock, err := newBackend() + require.NoError(t, err) + defer func() { require.NoError(t, uut.Close()) }() + + prefix := MakePrefix() + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + item := &backend.Item{ + Key: prefix("/db/database_tail_item"), + Value: []byte("data"), + Expires: clock.Now().Add(time.Minute), + } + _, err = uut.Put(ctx, *item) + require.NoError(t, err) + for i := 0; i < 10; i++ { + item := &backend.Item{ + Key: prefix(fmt.Sprintf("/db/database%d", i)), + Value: []byte("data"), + Expires: clock.Now().Add(time.Second * 10), + } + _, err = uut.Put(ctx, *item) + require.NoError(t, err) + } + clock.Advance(time.Second * 20) + + item = &backend.Item{ + Key: prefix("/db/database_head_item"), + Value: []byte("data"), + Expires: clock.Now().Add(time.Minute), + } + _, err = uut.Put(ctx, *item) + require.NoError(t, err) + + result, err := uut.GetRange(ctx, prefix("/db"), backend.RangeEnd(prefix("/db")), 2) + require.NoError(t, err) + require.Equal(t, 2, len(result.Items)) +} + // requireEvent asserts that a given event type with the given key is emitted // by a watcher within the supplied timeout, returning that event for further // inspection if successful.