Skip to content

Commit

Permalink
use fsnotify for basic auth userfile updates
Browse files Browse the repository at this point in the history
  • Loading branch information
Samu Tamminen committed Jan 13, 2022
1 parent 6f08125 commit e4aeaba
Show file tree
Hide file tree
Showing 3 changed files with 118 additions and 80 deletions.
1 change: 1 addition & 0 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ require (
github.com/facebookgo/stack v0.0.0-20160209184415-751773369052 // indirect
github.com/facebookgo/subset v0.0.0-20200203212716-c811ad88dec4 // indirect
github.com/fatih/color v1.12.0
github.com/fsnotify/fsnotify v1.4.9 // indirect
github.com/ghodss/yaml v1.0.0
github.com/go-chi/chi/v5 v5.0.3
github.com/go-zookeeper/zk v1.0.2
Expand Down
124 changes: 57 additions & 67 deletions pkg/filter/validator/basicauth.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,12 +23,10 @@ import (
"encoding/base64"
"fmt"
"io"
"io/fs"
"os"
"strings"
"sync"
"time"

"github.com/fsnotify/fsnotify"
"github.com/tg123/go-htpasswd"
"golang.org/x/crypto/bcrypt"

Expand Down Expand Up @@ -62,15 +60,14 @@ type (
// AuthorizedUsersCache provides cached lookup for authorized users.
AuthorizedUsersCache interface {
Match(string, string) bool
WatchChanges() error
Refresh() error
WatchChanges()
Close()
}

htpasswdUserCache struct {
userFile string
userFileObject *htpasswd.File
fileMutex sync.RWMutex
watcher *fsnotify.Watcher
syncInterval time.Duration
stopCtx context.Context
cancel context.CancelFunc
Expand Down Expand Up @@ -118,67 +115,59 @@ func newHtpasswdUserCache(userFile string, syncInterval time.Duration) *htpasswd
stopCtx, cancel := context.WithCancel(context.Background())
userFileObject, err := htpasswd.New(userFile, htpasswd.DefaultSystems, nil)
if err != nil {
panic(err)
logger.Errorf(err.Error())
userFileObject = nil
}
watcher, err := fsnotify.NewWatcher()
if err != nil {
logger.Errorf(err.Error())
watcher = nil
}
return &htpasswdUserCache{
userFile: userFile,
stopCtx: stopCtx,
cancel: cancel,
watcher: watcher,
userFileObject: userFileObject,
// Removed access or updated passwords are updated according syncInterval.
syncInterval: syncInterval,
}
}

// Refresh reloads users from userFile.
func (huc *htpasswdUserCache) Refresh() error {
huc.fileMutex.RLock()
err := huc.userFileObject.Reload(nil)
huc.fileMutex.RUnlock()
return err
}

func (huc *htpasswdUserCache) WatchChanges() error {
getFileStat := func() (fs.FileInfo, error) {
huc.fileMutex.RLock()
stat, err := os.Stat(huc.userFile)
huc.fileMutex.RUnlock()
return stat, err
}

initialStat, err := getFileStat()
if err != nil {
return err
func (huc *htpasswdUserCache) WatchChanges() {
if huc.userFileObject == nil || huc.watcher == nil {
return
}
for {
stat, err := getFileStat()
if err != nil {
return err
}
if stat.Size() != initialStat.Size() || stat.ModTime() != initialStat.ModTime() {
err := huc.Refresh()
if err != nil {
return err
go func() {
for {
select {
case _, ok := <-huc.watcher.Events:
if !ok {
return
}
err := huc.userFileObject.Reload(nil)
if err != nil {
logger.Errorf(err.Error())
}
case err, ok := <-huc.watcher.Errors:
if !ok {
return
}
logger.Errorf(err.Error())
}

// reset initial stat and watch for next modification
initialStat, err = getFileStat()
if err != nil {
return err
}
}
select {
case <-time.After(huc.syncInterval):
continue
case <-huc.stopCtx.Done():
return nil
}
}()
err := huc.watcher.Add(huc.userFile)
if err != nil {
logger.Errorf(err.Error())
}
return nil
return
}

func (huc *htpasswdUserCache) Close() {
huc.cancel()
if huc.watcher != nil {
huc.watcher.Close()
}
}

func (huc *htpasswdUserCache) Match(username string, password string) bool {
Expand Down Expand Up @@ -260,10 +249,10 @@ func kvsToReader(kvs map[string]string) io.Reader {
return strings.NewReader(stringData)
}

func (euc *etcdUserCache) WatchChanges() error {
func (euc *etcdUserCache) WatchChanges() {
if euc.prefix == "" {
logger.Errorf("missing etcd prefix, skip watching changes")
return nil
return
}
var (
syncer *cluster.Syncer
Expand All @@ -285,22 +274,25 @@ func (euc *etcdUserCache) WatchChanges() error {
select {
case <-time.After(10 * time.Second):
case <-euc.stopCtx.Done():
return nil
return
}
}
defer syncer.Close()

for {
select {
case <-euc.stopCtx.Done():
return nil
case kvs := <-ch:
logger.Infof("basic auth credentials update")
pwReader := kvsToReader(kvs)
euc.userFileObject.ReloadFromReader(pwReader, nil)
// start listening in background
go func() {
defer syncer.Close()

for {
select {
case <-euc.stopCtx.Done():
return
case kvs := <-ch:
logger.Infof("basic auth credentials update")
pwReader := kvsToReader(kvs)
euc.userFileObject.ReloadFromReader(pwReader, nil)
}
}
}
return nil
}()
return
}

func (euc *etcdUserCache) Close() {
Expand All @@ -310,8 +302,6 @@ func (euc *etcdUserCache) Close() {
euc.cancel()
}

func (euc *etcdUserCache) Refresh() error { return nil }

func (euc *etcdUserCache) Match(username string, password string) bool {
if euc.prefix == "" {
return false
Expand All @@ -335,7 +325,7 @@ func NewBasicAuthValidator(spec *BasicAuthValidatorSpec, supervisor *supervisor.
logger.Errorf("BasicAuth validator spec unvalid.")
return nil
}
go cache.WatchChanges()
cache.WatchChanges()
bav := &BasicAuthValidator{
spec: spec,
authorizedUsersCache: cache,
Expand Down
73 changes: 60 additions & 13 deletions pkg/filter/validator/validator_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -314,6 +314,14 @@ func prepareCtxAndHeader() (*contexttest.MockedHTTPContext, http.Header) {
return ctx, header
}

func cleanFile(userFile *os.File) {
err := userFile.Truncate(0)
check(err)
_, err = userFile.Seek(0, 0)
check(err)
userFile.Write([]byte(""))
}

func TestBasicAuth(t *testing.T) {
userIds := []string{
"userY", "userZ", "nonExistingUser",
Expand All @@ -330,6 +338,19 @@ func TestBasicAuth(t *testing.T) {
encrypt("userpasswordY"), encrypt("userpasswordZ"), encrypt("userpasswordX"),
}

t.Run("unexisting userFile", func(t *testing.T) {
yamlSpec := `
kind: Validator
name: validator
basicAuth:
mode: FILE
userFile: unexisting-file`
v := createValidator(yamlSpec, nil, nil)
ctx, _ := prepareCtxAndHeader()
if v.Handle(ctx) != resultInvalid {
t.Errorf("should be invalid")
}
})
t.Run("credentials from userFile", func(t *testing.T) {
userFile, err := os.CreateTemp("", "apache2-htpasswd")
check(err)
Expand All @@ -341,11 +362,21 @@ basicAuth:
mode: FILE
userFile: ` + userFile.Name()

// test invalid format
userFile.Write([]byte("keypass"))
v := createValidator(yamlSpec, nil, nil)
ctx, _ := prepareCtxAndHeader()
if v.Handle(ctx) != resultInvalid {
t.Errorf("should be invalid")
}

// now proper format
cleanFile(userFile)
userFile.Write(
[]byte(userIds[0] + ":" + encryptedPasswords[0] + "\n" + userIds[1] + ":" + encryptedPasswords[1]))
expectedValid := []bool{true, true, false}

v := createValidator(yamlSpec, nil, nil)
v = createValidator(yamlSpec, nil, nil)
for i := 0; i < 3; i++ {
ctx, header := prepareCtxAndHeader()
b64creds := base64.StdEncoding.EncodeToString([]byte(userIds[i] + ":" + passwords[i]))
Expand All @@ -362,25 +393,41 @@ basicAuth:
}
}

err = userFile.Truncate(0)
check(err)
_, err = userFile.Seek(0, 0)
check(err)
userFile.Write([]byte("")) // no more authorized users
v.basicAuth.authorizedUsersCache.Refresh()
cleanFile(userFile) // no more authorized users

ctx, header := prepareCtxAndHeader()
b64creds := base64.StdEncoding.EncodeToString([]byte(userIds[0] + ":" + passwords[0]))
header.Set("Authorization", "Basic "+b64creds)
result := v.Handle(ctx)
if result != resultInvalid {
t.Errorf("should be unauthorized")
tryCount := 5
for i := 0; i <= tryCount; i++ {
time.Sleep(200 * time.Millisecond) // wait that cache item gets deleted
ctx, header := prepareCtxAndHeader()
b64creds := base64.StdEncoding.EncodeToString([]byte(userIds[0] + ":" + passwords[0]))
header.Set("Authorization", "Basic "+b64creds)
result := v.Handle(ctx)
if result == resultInvalid {
break // successfully unauthorized
}
if i == tryCount && result != resultInvalid {
t.Errorf("should be unauthorized")
}
}

os.Remove(userFile.Name())
v.Close()
})

t.Run("test kvsToReader", func(t *testing.T) {
kvs := make(map[string]string)
kvs["/creds/key1"] = "key: key1\npass: pw" // invalid
kvs["/creds/key2"] = "ky: key2\npassword: pw" // invalid
kvs["/creds/key3"] = "key: key3\npassword: pw" // valid
reader := kvsToReader(kvs)
b, err := io.ReadAll(reader)
check(err)
s := string(b)
if s != "key3:pw" {
t.Errorf("parsing failed, %s", s)
}
})

t.Run("credentials from etcd", func(t *testing.T) {
etcdDirName, err := ioutil.TempDir("", "etcd-validator-test")
check(err)
Expand Down

0 comments on commit e4aeaba

Please sign in to comment.