Skip to content
This repository has been archived by the owner on Mar 29, 2024. It is now read-only.

Fix and improve error handling of user account mgmt #171

Merged
merged 1 commit into from
Jul 28, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions access/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -178,7 +178,7 @@ func updateUserWithFailedRequest(statusCode int, disableSubscription bool) {
// Get user from database.
user, err := GetUser()
if err != nil {
if !errors.Is(err, database.ErrNotFound) {
if !errors.Is(err, ErrNotLoggedIn) {
log.Warningf("spn/access: failed to get user to update with failed request: %s", err)
}
return
Expand Down Expand Up @@ -216,7 +216,7 @@ func Login(username, password string) (user *UserRecord, code int, err error) {
// Get previous user.
previousUser, err := GetUser()
if err != nil {
if !errors.Is(err, database.ErrNotFound) {
if !errors.Is(err, ErrNotLoggedIn) {
log.Warningf("spn/access: failed to get previous for re-login: %s", err)
}
previousUser = nil
Expand Down Expand Up @@ -325,7 +325,7 @@ func Logout(shallow, purge bool) error {
// Else, just update the user.
user, err := GetUser()
if err != nil {
if errors.Is(err, database.ErrNotFound) {
if errors.Is(err, ErrNotLoggedIn) {
return nil
}
return fmt.Errorf("failed to load user for logout: %w", err)
Expand Down
23 changes: 20 additions & 3 deletions access/database.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package access

import (
"errors"
"fmt"
"net/http"
"sync"
Expand Down Expand Up @@ -87,31 +88,44 @@ func (authToken *AuthTokenRecord) Update(resp *http.Response) error {
}

var (
cachedUser *UserRecord
cachedAuthToken *AuthTokenRecord
accountCacheLock sync.Mutex

cachedUser *UserRecord
cachedUserSet bool

cachedAuthToken *AuthTokenRecord
)

func clearUserCaches() {
accountCacheLock.Lock()
defer accountCacheLock.Unlock()

cachedUser = nil
cachedUserSet = false
cachedAuthToken = nil
}

// GetUser returns the current user account.
// Returns nil when no user is logged in.
func GetUser() (*UserRecord, error) {
// Check cache.
accountCacheLock.Lock()
defer accountCacheLock.Unlock()
if cachedUser != nil {
if cachedUserSet {
if cachedUser == nil {
return nil, ErrNotLoggedIn
}
return cachedUser, nil
}

// Load from disk.
r, err := db.Get(userRecordKey)
if err != nil {
if errors.Is(err, database.ErrNotFound) {
cachedUser = nil
cachedUserSet = true
return nil, ErrNotLoggedIn
}
return nil, err
}

Expand All @@ -124,6 +138,7 @@ func GetUser() (*UserRecord, error) {
return nil, err
}
cachedUser = newUser
cachedUserSet = true
return cachedUser, nil
}

Expand All @@ -133,6 +148,7 @@ func GetUser() (*UserRecord, error) {
return nil, fmt.Errorf("record not of type *UserRecord, but %T", r)
}
cachedUser = newUser
cachedUserSet = true
return cachedUser, nil
}

Expand All @@ -142,6 +158,7 @@ func (user *UserRecord) Save() error {
accountCacheLock.Lock()
defer accountCacheLock.Unlock()
cachedUser = user
cachedUserSet = true

// Update view if unset.
if user.View == nil {
Expand Down
3 changes: 1 addition & 2 deletions captain/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@ import (

"github.com/tevino/abool"

"github.com/safing/portbase/database"
"github.com/safing/portbase/log"
"github.com/safing/portbase/notifications"
"github.com/safing/portmaster/netenv"
Expand Down Expand Up @@ -206,7 +205,7 @@ func clientCheckNetworkReady(ctx context.Context) clientComponentResult {
func clientCheckAccountAndTokens(ctx context.Context) clientComponentResult {
// Get SPN user.
user, err := access.GetUser()
if err != nil && !errors.Is(err, database.ErrNotFound) {
if err != nil && !errors.Is(err, access.ErrNotLoggedIn) {
notifications.NotifyError(
"spn:failed-to-get-user",
"SPN Internal Error",
Expand Down