Skip to content

Commit

Permalink
Add tests for concurrent AccessControl store usage (#286)
Browse files Browse the repository at this point in the history
* refactor(accesscontrol): use interface for AccessStore cache

* refactor(accesscontrol): early return when cache is disabled

* test(accesscontrol): add failing unit test

* test(accesscontrol): skip failing test
  • Loading branch information
aruiz14 authored Oct 8, 2024
1 parent 99e479b commit 5c1a562
Show file tree
Hide file tree
Showing 2 changed files with 102 additions and 15 deletions.
40 changes: 25 additions & 15 deletions pkg/accesscontrol/access_store.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,11 +31,18 @@ type roleRevisions interface {
roleRevision(string, string) string
}

// accessStoreCache is a subset of the methods implemented by LRUExpireCache
type accessStoreCache interface {
Add(key interface{}, value interface{}, ttl time.Duration)
Get(key interface{}) (interface{}, bool)
Remove(key interface{})
}

type AccessStore struct {
usersPolicyRules policyRules
groupsPolicyRules policyRules
roles roleRevisions
cache *cache.LRUExpireCache
cache accessStoreCache
}

type roleKey struct {
Expand All @@ -56,26 +63,29 @@ func NewAccessStore(ctx context.Context, cacheResults bool, rbac v1.Interface) *
}

func (l *AccessStore) AccessFor(user user.Info) *AccessSet {
var cacheKey string
if l.cache != nil {
cacheKey = l.CacheKey(user)
val, ok := l.cache.Get(cacheKey)
if ok {
as, _ := val.(*AccessSet)
return as
}
if l.cache == nil {
return l.newAccessSet(user)
}

cacheKey := l.CacheKey(user)

if val, ok := l.cache.Get(cacheKey); ok {
as, _ := val.(*AccessSet)
return as
}

result := l.newAccessSet(user)
result.ID = cacheKey
l.cache.Add(cacheKey, result, 24*time.Hour)

return result
}

func (l *AccessStore) newAccessSet(user user.Info) *AccessSet {
result := l.usersPolicyRules.get(user.GetName())
for _, group := range user.GetGroups() {
result.Merge(l.groupsPolicyRules.get(group))
}

if l.cache != nil {
result.ID = cacheKey
l.cache.Add(cacheKey, result, 24*time.Hour)
}

return result
}

Expand Down
77 changes: 77 additions & 0 deletions pkg/accesscontrol/access_store_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,9 @@ package accesscontrol
import (
"fmt"
"slices"
"sync"
"testing"
"time"

appsv1 "k8s.io/api/apps/v1"
corev1 "k8s.io/api/core/v1"
Expand Down Expand Up @@ -274,6 +276,81 @@ func TestAccessStore_AccessFor(t *testing.T) {
}
}

type spyCache struct {
accessStoreCache

mu sync.Mutex
setCalls map[any]int
}

func (c *spyCache) Add(k interface{}, v interface{}, ttl time.Duration) {
defer c.observeAdd(k)

time.Sleep(1 * time.Millisecond) // allow other routines to wake up, simulating heavy load
c.accessStoreCache.Add(k, v, ttl)
}

func (c *spyCache) observeAdd(k interface{}) {
c.mu.Lock()
defer c.mu.Unlock()

if c.setCalls == nil {
c.setCalls = make(map[any]int)
}
c.setCalls[k]++
}

func TestAccessStore_AccessFor_concurrent(t *testing.T) {
t.Skipf("TODO - Add a fix for this test")
testUser := &user.DefaultInfo{Name: "test-user"}
asCache := &spyCache{accessStoreCache: cache.NewLRUExpireCache(100)}
store := &AccessStore{
roles: roleRevisionsMock(func(ns, name string) string {
return fmt.Sprintf("%s%srev", ns, name)
}),
usersPolicyRules: &policyRulesMock{
getRBFunc: func(s string) []*rbacv1.RoleBinding {
return []*rbacv1.RoleBinding{
makeRB("testns", "testrb", testUser.Name, "testrole"),
}
},
getFunc: func(_ string) *AccessSet {
return &AccessSet{
set: map[key]resourceAccessSet{
{"get", corev1.Resource("ConfigMap")}: map[Access]bool{
{Namespace: All, ResourceName: All}: true,
},
},
}
},
},
cache: asCache,
}

const n = 5 // observation showed cases with up to 5 (or more) concurrent queries for the same user

wait := make(chan struct{})
var wg sync.WaitGroup
var id string
for range n {
wg.Add(1)
go func() {
<-wait
id = store.AccessFor(testUser).ID
wg.Done()
}()
}
close(wait)
wg.Wait()

if got, want := len(asCache.setCalls), 1; got != want {
t.Errorf("Unexpected number of cache entries: got %d, want %d", got, want)
}
if got, want := asCache.setCalls[id], 1; got != want {
t.Errorf("Unexpected number of calls to cache.Set(): got %d, want %d", got, want)
}
}

func makeRB(ns, name, user, role string) *rbacv1.RoleBinding {
return &rbacv1.RoleBinding{
ObjectMeta: metav1.ObjectMeta{Namespace: ns, Name: name},
Expand Down

0 comments on commit 5c1a562

Please sign in to comment.