Skip to content

Commit

Permalink
Read all pages when list results are paged (#4983)
Browse files Browse the repository at this point in the history
  • Loading branch information
chrishoffman authored Jul 24, 2018
1 parent 6a61077 commit fc1fefd
Show file tree
Hide file tree
Showing 2 changed files with 88 additions and 19 deletions.
50 changes: 33 additions & 17 deletions physical/azure/azure.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,17 +12,20 @@ import (
"time"

storage "github.com/Azure/azure-sdk-for-go/storage"
log "github.com/hashicorp/go-hclog"

"github.com/armon/go-metrics"
"github.com/hashicorp/errwrap"
cleanhttp "github.com/hashicorp/go-cleanhttp"
log "github.com/hashicorp/go-hclog"
"github.com/hashicorp/vault/helper/strutil"
"github.com/hashicorp/vault/physical"
)

// MaxBlobSize at this time
var MaxBlobSize = 1024 * 1024 * 4
const (
// MaxBlobSize at this time
MaxBlobSize = 1024 * 1024 * 4
// MaxListResults is the current default value, setting explicitly
MaxListResults = 5000
)

// AzureBackend is a physical backend that stores data
// within an Azure blob container.
Expand Down Expand Up @@ -180,22 +183,35 @@ func (a *AzureBackend) List(ctx context.Context, prefix string) ([]string, error
defer metrics.MeasureSince([]string{"azure", "list"}, time.Now())

a.permitPool.Acquire()
list, err := a.container.ListBlobs(storage.ListBlobsParameters{Prefix: prefix})
if err != nil {
// Break early.
a.permitPool.Release()
return nil, err
}
a.permitPool.Release()
defer a.permitPool.Release()

var marker string
keys := []string{}
for _, blob := range list.Blobs {
key := strings.TrimPrefix(blob.Name, prefix)
if i := strings.Index(key, "/"); i == -1 {
keys = append(keys, key)
} else {
keys = strutil.AppendIfMissing(keys, key[:i+1])
for {
list, err := a.container.ListBlobs(storage.ListBlobsParameters{
Prefix: prefix,
Marker: marker,
MaxResults: MaxListResults,
})
if err != nil {
return nil, err
}

for _, blob := range list.Blobs {
key := strings.TrimPrefix(blob.Name, prefix)
if i := strings.Index(key, "/"); i == -1 {
// file
keys = append(keys, key)
} else {
// subdirectory
keys = strutil.AppendIfMissing(keys, key[:i+1])
}
}

if list.NextMarker == "" {
break
}
marker = list.NextMarker
}

sort.Strings(keys)
Expand Down
57 changes: 55 additions & 2 deletions physical/azure/azure_test.go
Original file line number Diff line number Diff line change
@@ -1,17 +1,18 @@
package azure

import (
"context"
"fmt"
"os"
"strconv"
"testing"
"time"

storage "github.com/Azure/azure-sdk-for-go/storage"
cleanhttp "github.com/hashicorp/go-cleanhttp"
log "github.com/hashicorp/go-hclog"
"github.com/hashicorp/vault/helper/logging"
"github.com/hashicorp/vault/physical"

storage "github.com/Azure/azure-sdk-for-go/storage"
)

func TestAzureBackend(t *testing.T) {
Expand Down Expand Up @@ -50,3 +51,55 @@ func TestAzureBackend(t *testing.T) {
physical.ExerciseBackend(t, backend)
physical.ExerciseBackend_ListPrefix(t, backend)
}

func TestAzureBackend_ListPaging(t *testing.T) {
if os.Getenv("AZURE_ACCOUNT_NAME") == "" ||
os.Getenv("AZURE_ACCOUNT_KEY") == "" {
t.SkipNow()
}

accountName := os.Getenv("AZURE_ACCOUNT_NAME")
accountKey := os.Getenv("AZURE_ACCOUNT_KEY")

ts := time.Now().UnixNano()
name := fmt.Sprintf("vault-test-%d", ts)

cleanupClient, _ := storage.NewBasicClient(accountName, accountKey)
cleanupClient.HTTPClient = cleanhttp.DefaultPooledClient()

logger := logging.NewVaultLogger(log.Debug)

backend, err := NewAzureBackend(map[string]string{
"container": name,
"accountName": accountName,
"accountKey": accountKey,
}, logger)

defer func() {
blobService := cleanupClient.GetBlobService()
container := blobService.GetContainerReference(name)
container.DeleteIfExists(nil)
}()

if err != nil {
t.Fatalf("err: %s", err)
}

// by default, azure returns 5000 results in a page, load up more than that
for i := 0; i < MaxListResults+100; i++ {
if err := backend.Put(context.Background(), &physical.Entry{
Key: strconv.Itoa(i),
Value: []byte(strconv.Itoa(i)),
}); err != nil {
t.Fatalf("err: %s", err)
}
}

results, err := backend.List(context.Background(), "")
if err != nil {
t.Fatalf("err: %s", err)
}
if len(results) != MaxListResults+100 {
t.Fatalf("expected %d, got %d", MaxListResults+100, len(results))
}
}

0 comments on commit fc1fefd

Please sign in to comment.