Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Add configurable caching of credentials to file #38

Draft
wants to merge 2 commits into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
126 changes: 27 additions & 99 deletions cache/cache.go
Original file line number Diff line number Diff line change
@@ -1,113 +1,41 @@
/*
* Copyright 2020 Netflix, Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package cache

import (
"fmt"
"strings"
"sync"

"github.com/netflix/weep/creds"
"github.com/netflix/weep/errors"
log "github.com/sirupsen/logrus"
"github.com/spf13/viper"
)

var GlobalCache CredentialCache

type CredentialCache struct {
RoleCredentials map[string]*creds.RefreshableProvider
DefaultRole string
mu sync.RWMutex
}
var GlobalCache WeepCache

func init() {
GlobalCache = CredentialCache{
RoleCredentials: make(map[string]*creds.RefreshableProvider),
}
}

// getCacheSlug returns a string unique to a particular combination of a role and chain of roles to assume.
func getCacheSlug(role string, assume []string) string {
var err error
cacheType := viper.GetString("cache.type")
log.Debugf("initializing %s cache", cacheType)
switch cacheType {
case "memory":
GlobalCache = NewMemoryCache()
case "file":
GlobalCache, err = NewFileCache()
if err != nil {
log.Fatalf("failed to initialize file cache: %v", err)
}
default:
log.Fatal("invalid cache type specified")
}
}

type WeepCache interface {
Get(role string, assumeChain []string) (*creds.RefreshableProvider, error)
GetOrSet(client *creds.Client, role string, region string, assumeChain []string) (*creds.RefreshableProvider, error)
SetDefault(client *creds.Client, role string, region string, assumeChain []string) error
GetDefault() (*creds.RefreshableProvider, error)
}

// getSlug returns a string unique to a particular combination of a role and chain of roles to assume.
func getSlug(role string, assume []string) string {
elements := append([]string{role}, assume...)
return strings.Join(elements, "/")
}

func (cc *CredentialCache) Get(role string, assumeChain []string) (*creds.RefreshableProvider, error) {
log.WithFields(log.Fields{
"role": role,
"assumeChain": assumeChain,
}).Info("retrieving credentials")
c, ok := cc.get(getCacheSlug(role, assumeChain))
if ok {
log.Debugf("found credentials for %s in cache", role)
return c, nil
}
return nil, errors.NoCredentialsFoundInCache
}

func (cc *CredentialCache) GetOrSet(client *creds.Client, role, region string, assumeChain []string) (*creds.RefreshableProvider, error) {
c, err := cc.Get(role, assumeChain)
if err == nil {
return c, nil
}
log.Debugf("no credentials for %s in cache, creating", role)

c, err = cc.set(client, role, region, assumeChain)
if err != nil {
return nil, err
}

return c, nil
}

func (cc *CredentialCache) SetDefault(client *creds.Client, role, region string, assumeChain []string) error {
_, err := cc.set(client, role, region, assumeChain)
if err != nil {
return err
}
cc.DefaultRole = getCacheSlug(role, assumeChain)
return nil
}

func (cc *CredentialCache) GetDefault() (*creds.RefreshableProvider, error) {
if cc.DefaultRole == "" {
return nil, errors.NoDefaultRoleSet
}
c, ok := cc.get(cc.DefaultRole)
if ok {
return c, nil
}
return nil, errors.NoCredentialsFoundInCache
}

func (cc *CredentialCache) get(slug string) (*creds.RefreshableProvider, bool) {
cc.mu.RLock()
defer cc.mu.RUnlock()
c, ok := cc.RoleCredentials[slug]
return c, ok
}

func (cc *CredentialCache) set(client *creds.Client, role, region string, assumeChain []string) (*creds.RefreshableProvider, error) {
c, err := creds.NewRefreshableProvider(client, role, region, assumeChain, false)
if err != nil {
return nil, fmt.Errorf("could not generate creds: %w", err)
}
cc.mu.Lock()
defer cc.mu.Unlock()
cc.RoleCredentials[getCacheSlug(role, assumeChain)] = c
return c, nil
}
128 changes: 128 additions & 0 deletions cache/file.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,128 @@
package cache

import (
"encoding/json"
"fmt"

"github.com/boltdb/bolt"
"github.com/netflix/weep/creds"
"github.com/netflix/weep/errors"
log "github.com/sirupsen/logrus"
)

const BUCKET = "credentials"

type FileDB struct {
db *bolt.DB
}

func NewFileCache() (*FileDB, error) {
db, err := bolt.Open("weep.db", 0600, nil)
if err != nil {
return nil, err
}

fdb := &FileDB{
db: db,
}
err = fdb.setup()
if err != nil {
return nil, err
}

return fdb, nil
}

func (f *FileDB) setup() error {
err := f.db.Update(func(tx *bolt.Tx) error {
_, err := tx.CreateBucketIfNotExists([]byte(BUCKET))
if err != nil {
return err
}
return nil
})
return err
}

func (f *FileDB) Get(role string, assumeChain []string) (*creds.RefreshableProvider, error) {
log.WithFields(log.Fields{
"role": role,
"assumeChain": assumeChain,
"cacheType": "file",
}).Info("retrieving credentials")
c, err := f.get(getSlug(role, assumeChain))
if err != nil {
return nil, errors.NoCredentialsFoundInCache
}
return c, nil
}

func (f *FileDB) GetOrSet(client *creds.Client, role string, region string, assumeChain []string) (*creds.RefreshableProvider, error) {
c, err := f.Get(role, assumeChain)
if err == nil {
return c, nil
}
log.Debugf("no credentials for %s in cache, creating", role)

c, err = f.set(client, role, region, assumeChain)
if err != nil {
return nil, err
}

return c, nil
}

func (f *FileDB) SetDefault(client *creds.Client, role string, region string, assumeChain []string) error {
// TODO
return nil
}

func (f *FileDB) GetDefault() (*creds.RefreshableProvider, error) {
// TODO
return nil, nil
}

func (f *FileDB) get(slug string) (*creds.RefreshableProvider, error) {
credentials := &creds.RefreshableProvider{}
err := f.db.View(func(tx *bolt.Tx) error {
b := tx.Bucket([]byte(BUCKET))
result := b.Get([]byte(slug))
err := json.Unmarshal(result, credentials)
if err != nil {
return nil
}
return nil
})
if err != nil {
return credentials, err
}
err = credentials.EnsureRefreshed()
if err != nil {
return credentials, err
}
return credentials, nil
}

func (f *FileDB) set(client *creds.Client, role, region string, assumeChain []string) (*creds.RefreshableProvider, error) {
c, err := creds.NewRefreshableProvider(client, role, region, assumeChain, false)
if err != nil {
return nil, fmt.Errorf("could not generate creds: %w", err)
}
data, err := json.Marshal(c)
slug := getSlug(role, assumeChain)
if err != nil {
return nil, fmt.Errorf("could not marshal creds: %w", err)
}
err = f.db.Update(func(tx *bolt.Tx) error {
b := tx.Bucket([]byte(BUCKET))
err := b.Put([]byte(slug), data)
if err != nil {
return err
}
return nil
})
if err != nil {
return nil, err
}
return c, nil
}
88 changes: 88 additions & 0 deletions cache/memory.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
package cache

import (
"fmt"
"sync"

"github.com/netflix/weep/creds"
"github.com/netflix/weep/errors"
log "github.com/sirupsen/logrus"
)

type InMemory struct {
RoleCredentials map[string]*creds.RefreshableProvider
DefaultRole string
mu sync.RWMutex
}

func NewMemoryCache() *InMemory {
return &InMemory{
RoleCredentials: make(map[string]*creds.RefreshableProvider),
}
}

func (cc *InMemory) Get(role string, assumeChain []string) (*creds.RefreshableProvider, error) {
log.WithFields(log.Fields{
"role": role,
"assumeChain": assumeChain,
}).Info("retrieving credentials")
c, ok := cc.get(getSlug(role, assumeChain))
if ok {
log.Debugf("found credentials for %s in cache", role)
return c, nil
}
return nil, errors.NoCredentialsFoundInCache
}

func (cc *InMemory) GetOrSet(client *creds.Client, role, region string, assumeChain []string) (*creds.RefreshableProvider, error) {
c, err := cc.Get(role, assumeChain)
if err == nil {
return c, nil
}
log.Debugf("no credentials for %s in cache, creating", role)

c, err = cc.set(client, role, region, assumeChain)
if err != nil {
return nil, err
}

return c, nil
}

func (cc *InMemory) SetDefault(client *creds.Client, role, region string, assumeChain []string) error {
_, err := cc.set(client, role, region, assumeChain)
if err != nil {
return err
}
cc.DefaultRole = getSlug(role, assumeChain)
return nil
}

func (cc *InMemory) GetDefault() (*creds.RefreshableProvider, error) {
if cc.DefaultRole == "" {
return nil, errors.NoDefaultRoleSet
}
c, ok := cc.get(cc.DefaultRole)
if ok {
return c, nil
}
return nil, errors.NoCredentialsFoundInCache
}

func (cc *InMemory) get(slug string) (*creds.RefreshableProvider, bool) {
cc.mu.RLock()
defer cc.mu.RUnlock()
c, ok := cc.RoleCredentials[slug]
return c, ok
}

func (cc *InMemory) set(client *creds.Client, role, region string, assumeChain []string) (*creds.RefreshableProvider, error) {
c, err := creds.NewRefreshableProvider(client, role, region, assumeChain, false)
if err != nil {
return nil, fmt.Errorf("could not generate creds: %w", err)
}
cc.mu.Lock()
defer cc.mu.Unlock()
cc.RoleCredentials[getSlug(role, assumeChain)] = c
return c, nil
}
8 changes: 4 additions & 4 deletions cache/cache_test.go → cache/memory_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ func TestCredentialCache_Get(t *testing.T) {

for i, tc := range cases {
t.Logf("test case %d: %s", i, tc.Description)
testCache := CredentialCache{
testCache := InMemory{
RoleCredentials: tc.CacheContents,
}
actualResult, actualError := testCache.Get(tc.Role, tc.AssumeChain)
Expand Down Expand Up @@ -167,7 +167,7 @@ func TestCredentialCache_GetDefault(t *testing.T) {

for i, tc := range cases {
t.Logf("test case %d: %s", i, tc.Description)
testCache := CredentialCache{
testCache := InMemory{
RoleCredentials: tc.CacheContents,
DefaultRole: tc.DefaultRole,
}
Expand All @@ -183,7 +183,7 @@ func TestCredentialCache_GetDefault(t *testing.T) {
}

func TestCredentialCache_SetDefault(t *testing.T) {
testCache := CredentialCache{
testCache := InMemory{
RoleCredentials: map[string]*creds.RefreshableProvider{},
}
expectedRole := "a"
Expand Down Expand Up @@ -255,7 +255,7 @@ func TestCredentialCache_GetOrSet(t *testing.T) {

for i, tc := range cases {
t.Logf("test case %d: %s", i, tc.Description)
testCache := CredentialCache{
testCache := InMemory{
RoleCredentials: tc.CacheContents,
}
client, err := creds.GetTestClient(creds.ConsolemeCredentialResponseType{
Expand Down
Loading