Skip to content

Commit

Permalink
Add support for Google Cloud Spanner (hashicorp#3977)
Browse files Browse the repository at this point in the history
  • Loading branch information
sethvargo authored and jefferai committed Feb 15, 2018
1 parent 6420c55 commit 7af2bdc
Show file tree
Hide file tree
Showing 97 changed files with 21,916 additions and 98 deletions.
2 changes: 2 additions & 0 deletions command/commands.go
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@ import (
physMySQL "github.com/hashicorp/vault/physical/mysql"
physPostgreSQL "github.com/hashicorp/vault/physical/postgresql"
physS3 "github.com/hashicorp/vault/physical/s3"
physSpanner "github.com/hashicorp/vault/physical/spanner"
physSwift "github.com/hashicorp/vault/physical/swift"
physZooKeeper "github.com/hashicorp/vault/physical/zookeeper"
)
Expand Down Expand Up @@ -134,6 +135,7 @@ var (
"mysql": physMySQL.NewMySQLBackend,
"postgresql": physPostgreSQL.NewPostgreSQLBackend,
"s3": physS3.NewS3Backend,
"spanner": physSpanner.NewBackend,
"swift": physSwift.NewSwiftBackend,
"zookeeper": physZooKeeper.NewZooKeeperBackend,
}
Expand Down
343 changes: 343 additions & 0 deletions physical/spanner/spanner.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,343 @@
package spanner

import (
"fmt"
"os"
"sort"
"strconv"
"strings"
"time"

metrics "github.com/armon/go-metrics"
"github.com/hashicorp/errwrap"
"github.com/hashicorp/vault/helper/strutil"
"github.com/hashicorp/vault/physical"
log "github.com/mgutz/logxi/v1"
"google.golang.org/api/iterator"
"google.golang.org/grpc/codes"

"cloud.google.com/go/spanner"
"github.com/pkg/errors"
"golang.org/x/net/context"
)

// Verify Backend satisfies the correct interfaces
var _ physical.Backend = (*Backend)(nil)
var _ physical.Transactional = (*Backend)(nil)

const (
// envDatabase is the name of the environment variable to search for the
// database name.
envDatabase = "GOOGLE_SPANNER_DATABASE"

// envHAEnabled is the name of the environment variable to search for the
// boolean indicating if HA is enabled.
envHAEnabled = "GOOGLE_SPANNER_HA_ENABLED"

// envHATable is the name of the environment variable to search for the table
// name to use for HA.
envHATable = "GOOGLE_SPANNER_HA_TABLE"

// envTable is the name of the environment variable to search for the table
// name.
envTable = "GOOGLE_SPANNER_TABLE"

// defaultTable is the default table name if none is specified.
defaultTable = "Vault"

// defaultHASuffix is the default suffix to apply to the table name if no
// HA table is provided.
defaultHASuffix = "HA"
)

var (
// metricDelete is the key for the metric for measuring a Delete call.
metricDelete = []string{"spanner", "delete"}

// metricGet is the key for the metric for measuring a Get call.
metricGet = []string{"spanner", "get"}

// metricList is the key for the metric for measuring a List call.
metricList = []string{"spanner", "list"}

// metricPut is the key for the metric for measuring a Put call.
metricPut = []string{"spanner", "put"}
)

// Backend implements physical.Backend and describes the steps necessary to
// persist data using Google Cloud Spanner.
type Backend struct {
// database is the name of the database to use for data storage and retrieval.
// This is supplied as part of user configuration.
database string

// table is the name of the table in the database.
table string

// haTable is the name of the table to use for HA in the database.
haTable string

// haEnabled indicates if high availability is enabled. Default: true.
haEnabled bool

// client is the underlying API client for talking to spanner.
client *spanner.Client

// logger and permitPool are internal constructs.
logger log.Logger
permitPool *physical.PermitPool
}

// NewBackend creates a new Google Spanner storage backend with the given
// configuration. This uses the official Golang Cloud SDK and therefore supports
// specifying credentials via envvars, credential files, etc.
func NewBackend(c map[string]string, logger log.Logger) (physical.Backend, error) {
logger.Debug("physical/spanner: configuring backend")

// Database name
database := os.Getenv(envDatabase)
if database == "" {
database = c["database"]
}
if database == "" {
return nil, errors.New("missing database name")
}

// Table name
table := os.Getenv(envTable)
if table == "" {
table = c["table"]
}
if table == "" {
table = defaultTable
}

// HA table name
haTable := os.Getenv(envHATable)
if haTable == "" {
haTable = c["ha_table"]
}
if haTable == "" {
haTable = table + defaultHASuffix
}

// HA configuration
haEnabled := false
haEnabledStr := os.Getenv(envHAEnabled)
if haEnabledStr == "" {
haEnabledStr = c["ha_enabled"]
}
if haEnabledStr != "" {
var err error
haEnabled, err = strconv.ParseBool(haEnabledStr)
if err != nil {
return nil, errwrap.Wrapf("failed to parse HA enabled: {{err}}", err)
}
}

// Max parallel
maxParallel, err := extractInt(c["max_parallel"])
if err != nil {
return nil, errwrap.Wrapf("failed to parse max_parallel: {{err}}", err)
}

logger.Debug("physical/spanner: configuration",
"database", database,
"table", table,
"haEnabled", haEnabled,
"haTable", haTable,
"maxParallel", maxParallel,
)
logger.Debug("physical/spanner: creating client")

ctx := context.Background()
client, err := spanner.NewClient(ctx, database)
if err != nil {
return nil, errwrap.Wrapf("failed to create spanner client: {{err}}", err)
}

return &Backend{
database: database,
table: table,
haEnabled: haEnabled,
haTable: haTable,

client: client,
permitPool: physical.NewPermitPool(maxParallel),
logger: logger,
}, nil
}

// Put creates or updates an entry.
func (b *Backend) Put(ctx context.Context, entry *physical.Entry) error {
defer metrics.MeasureSince(metricPut, time.Now())

// Pooling
b.permitPool.Acquire()
defer b.permitPool.Release()

// Insert
m := spanner.InsertOrUpdateMap(b.table, map[string]interface{}{
"Key": entry.Key,
"Value": entry.Value,
})
if _, err := b.client.Apply(ctx, []*spanner.Mutation{m}); err != nil {
return errwrap.Wrapf("failed to put data: {{err}}", err)
}
return nil
}

// Get fetches an entry. If there is no entry, this function returns nil.
func (b *Backend) Get(ctx context.Context, key string) (*physical.Entry, error) {
defer metrics.MeasureSince(metricList, time.Now())

// Pooling
b.permitPool.Acquire()
defer b.permitPool.Release()

// Read
row, err := b.client.Single().ReadRow(ctx, b.table, spanner.Key{key}, []string{"Value"})
if spanner.ErrCode(err) == codes.NotFound {
return nil, nil
}
if err != nil {
return nil, errwrap.Wrapf(fmt.Sprintf("failed to read value for %q: {{err}}", key), err)
}

var value []byte
if err := row.Column(0, &value); err != nil {
return nil, errwrap.Wrapf("failed to decode value into bytes: {{err}}", err)
}

return &physical.Entry{
Key: key,
Value: value,
}, nil
}

// Delete deletes an entry with the given key.
func (b *Backend) Delete(ctx context.Context, key string) error {
defer metrics.MeasureSince(metricDelete, time.Now())

// Pooling
b.permitPool.Acquire()
defer b.permitPool.Release()

// Delete
m := spanner.Delete(b.table, spanner.Key{key})
if _, err := b.client.Apply(ctx, []*spanner.Mutation{m}); err != nil {
return errwrap.Wrapf("failed to delete key: {{err}}", err)
}

return nil
}

// List enumerates all keys with the given prefix.
func (b *Backend) List(ctx context.Context, prefix string) ([]string, error) {
defer metrics.MeasureSince(metricList, time.Now())

// Pooling
b.permitPool.Acquire()
defer b.permitPool.Release()

// Sanitize
safeTable := sanitizeTable(b.table)

// List
iter := b.client.Single().Query(ctx, spanner.Statement{
SQL: "SELECT Key FROM " + safeTable + " WHERE STARTS_WITH(Key, @prefix)",
Params: map[string]interface{}{
"prefix": prefix,
},
})
defer iter.Stop()

var keys []string

for {
row, err := iter.Next()
if err == iterator.Done {
break
}
if err != nil {
return nil, errwrap.Wrapf("failed to read row: {{err}}", err)
}

var key string
if err := row.Column(0, &key); err != nil {
return nil, errwrap.Wrapf("failed to decode key into string: {{err}}", err)
}

// The results will include the full prefix (folder) and any deeply-nested
// prefixes (subfolders). Vault expects only the top-most things to be
// included.
key = strings.TrimPrefix(key, prefix)
if i := strings.Index(key, "/"); i == -1 {
// Add objects only from the current 'folder'
keys = append(keys, key)
} else {
// Add truncated 'folder' paths
keys = strutil.AppendIfMissing(keys, string(key[:i+1]))
}
}

// Sort because the resulting order is not predictable
sort.Strings(keys)

return keys, nil
}

// Transaction runs multiple entries via a single transaction.
func (b *Backend) Transaction(ctx context.Context, txns []*physical.TxnEntry) error {
// Quit early if we can
if len(txns) == 0 {
return nil
}

// Build all the ops before taking out the pool
ms := make([]*spanner.Mutation, len(txns))
for i, tx := range txns {
op, key, value := tx.Operation, tx.Entry.Key, tx.Entry.Value

switch op {
case physical.DeleteOperation:
ms[i] = spanner.Delete(b.table, spanner.Key{key})
case physical.PutOperation:
ms[i] = spanner.InsertOrUpdateMap(b.table, map[string]interface{}{
"Key": key,
"Value": value,
})
default:
return fmt.Errorf("unsupported transaction operation: %q", op)
}
}

// Pooling
b.permitPool.Acquire()
defer b.permitPool.Release()

// Transactivate!
if _, err := b.client.Apply(ctx, ms); err != nil {
return errwrap.Wrapf("failed to commit transaction: {{err}}", err)
}

return nil
}

// extractInt is a helper function that takes a string and converts that string
// to an int, but accounts for the empty string.
func extractInt(s string) (int, error) {
if s == "" {
return 0, nil
}
return strconv.Atoi(s)
}

// sanitizeTable attempts to sanitize the table name.
func sanitizeTable(s string) string {
end := strings.IndexRune(s, 0)
if end > -1 {
s = s[:end]
}
return strings.Replace(s, `"`, `""`, -1)
}
Loading

0 comments on commit 7af2bdc

Please sign in to comment.