Skip to content

Commit

Permalink
fix(middleware/session): mutex for thread safety (#3049)
Browse files Browse the repository at this point in the history
* fix(middleware/session): mutex for thread safety

* chore: Remove extra release and acquire ctx calls in session_test.go

* feat: Remove unnecessary session mutex lock in decodeSessionData function
  • Loading branch information
sixcolors authored Jun 29, 2024
1 parent dbba6cf commit 83731ce
Show file tree
Hide file tree
Showing 3 changed files with 139 additions and 24 deletions.
40 changes: 32 additions & 8 deletions middleware/session/session.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ import (
)

type Session struct {
mu sync.RWMutex // Mutex to protect non-data fields
id string // session id
fresh bool // if new session
ctx fiber.Ctx // fiber context
Expand Down Expand Up @@ -56,11 +57,15 @@ func releaseSession(s *Session) {

// Fresh is true if the current session is new
func (s *Session) Fresh() bool {
s.mu.RLock()
defer s.mu.RUnlock()
return s.fresh
}

// ID returns the session id
func (s *Session) ID() string {
s.mu.RLock()
defer s.mu.RUnlock()
return s.id
}

Expand Down Expand Up @@ -101,6 +106,9 @@ func (s *Session) Destroy() error {
// Reset local data
s.data.Reset()

s.mu.Lock()
defer s.mu.Unlock()

// Use external Storage if exist
if err := s.config.Storage.Delete(s.id); err != nil {
return err
Expand All @@ -113,6 +121,9 @@ func (s *Session) Destroy() error {

// Regenerate generates a new session id and delete the old one from Storage
func (s *Session) Regenerate() error {
s.mu.Lock()
defer s.mu.Unlock()

// Delete old id from storage
if err := s.config.Storage.Delete(s.id); err != nil {
return err
Expand All @@ -137,6 +148,9 @@ func (s *Session) Reset() error {
// Reset expiration
s.exp = 0

s.mu.Lock()
defer s.mu.Unlock()

// Delete old id from storage
if err := s.config.Storage.Delete(s.id); err != nil {
return err
Expand All @@ -153,10 +167,7 @@ func (s *Session) Reset() error {

// refresh generates a new session, and set session.fresh to be true
func (s *Session) refresh() {
// Create a new id
s.id = s.config.KeyGenerator()

// We assign a new id to the session, so the session must be fresh
s.fresh = true
}

Expand All @@ -167,6 +178,9 @@ func (s *Session) Save() error {
return nil
}

s.mu.Lock()
defer s.mu.Unlock()

// Check if session has your own expiration, otherwise use default value
if s.exp <= 0 {
s.exp = s.config.Expiration
Expand All @@ -176,25 +190,23 @@ func (s *Session) Save() error {
s.setSession()

// Convert data to bytes
mux.Lock()
defer mux.Unlock()
encCache := gob.NewEncoder(s.byteBuffer)
err := encCache.Encode(&s.data.Data)
if err != nil {
return fmt.Errorf("failed to encode data: %w", err)
}

// copy the data in buffer
// Copy the data in buffer
encodedBytes := make([]byte, s.byteBuffer.Len())
copy(encodedBytes, s.byteBuffer.Bytes())

// pass copied bytes with session id to provider
// Pass copied bytes with session id to provider
if err := s.config.Storage.Set(s.id, encodedBytes, s.exp); err != nil {
return err
}

// Release session
// TODO: It's not safe to use the Session after called Save()
// TODO: It's not safe to use the Session after calling Save()
releaseSession(s)

return nil
Expand All @@ -210,6 +222,8 @@ func (s *Session) Keys() []string {

// SetExpiry sets a specific expiration for this session
func (s *Session) SetExpiry(exp time.Duration) {
s.mu.Lock()
defer s.mu.Unlock()
s.exp = exp
}

Expand Down Expand Up @@ -275,3 +289,13 @@ func (s *Session) delSession() {
fasthttp.ReleaseCookie(fcookie)
}
}

// decodeSessionData decodes the session data from raw bytes.
func (s *Session) decodeSessionData(rawData []byte) error {
_, _ = s.byteBuffer.Write(rawData)
encCache := gob.NewDecoder(s.byteBuffer)
if err := encCache.Decode(&s.data.Data); err != nil {
return fmt.Errorf("failed to decode session data: %w", err)
}
return nil
}
107 changes: 107 additions & 0 deletions middleware/session/session_test.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
package session

import (
"errors"
"sync"
"testing"
"time"

Expand Down Expand Up @@ -856,3 +858,108 @@ func Benchmark_Session_Asserted_Parallel(b *testing.B) {
})
})
}

// go test -v -race -run Test_Session_Concurrency ./...
func Test_Session_Concurrency(t *testing.T) {
t.Parallel()
app := fiber.New()
store := New()

var wg sync.WaitGroup
errChan := make(chan error, 10) // Buffered channel to collect errors
const numGoroutines = 10 // Number of concurrent goroutines to test

// Start numGoroutines goroutines
for i := 0; i < numGoroutines; i++ {
wg.Add(1)
go func() {
defer wg.Done()

localCtx := app.AcquireCtx(&fasthttp.RequestCtx{})

sess, err := store.Get(localCtx)
if err != nil {
errChan <- err
return
}

// Set a value
sess.Set("name", "john")

// get the session id
id := sess.ID()

// Check if the session is fresh
if !sess.Fresh() {
errChan <- errors.New("session should be fresh")
return
}

// Save the session
if err := sess.Save(); err != nil {
errChan <- err
return
}

// Release the context
app.ReleaseCtx(localCtx)

// Acquire a new context
localCtx = app.AcquireCtx(&fasthttp.RequestCtx{})
defer app.ReleaseCtx(localCtx)

// Set the session id in the header
localCtx.Request().Header.SetCookie(store.sessionName, id)

// Get the session
sess, err = store.Get(localCtx)
if err != nil {
errChan <- err
return
}

// Get the value
name := sess.Get("name")
if name != "john" {
errChan <- errors.New("name should be john")
return
}

// Get ID from the session
if sess.ID() != id {
errChan <- errors.New("id should be the same")
return
}

// Check if the session is fresh
if sess.Fresh() {
errChan <- errors.New("session should not be fresh")
return
}

// Delete the key
sess.Delete("name")

// Get the value
name = sess.Get("name")
if name != nil {
errChan <- errors.New("name should be nil")
return
}

// Destroy the session
if err := sess.Destroy(); err != nil {
errChan <- err
return
}
}()
}

wg.Wait() // Wait for all goroutines to finish
close(errChan) // Close the channel to signal no more errors will be sent

// Check for errors sent to errChan
for err := range errChan {
require.NoError(t, err)
}
}
16 changes: 0 additions & 16 deletions middleware/session/store.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@ import (
"encoding/gob"
"errors"
"fmt"
"sync"

"github.com/gofiber/fiber/v3"
"github.com/gofiber/fiber/v3/internal/storage/memory"
Expand All @@ -14,9 +13,6 @@ import (
// ErrEmptySessionID is an error that occurs when the session ID is empty.
var ErrEmptySessionID = errors.New("session id cannot be empty")

// mux is a global mutex for session operations.
var mux sync.Mutex

// sessionIDKey is the local key type used to store and retrieve the session ID in context.
type sessionIDKey int

Expand Down Expand Up @@ -132,15 +128,3 @@ func (s *Store) Delete(id string) error {
}
return s.Storage.Delete(id)
}

// decodeSessionData decodes the session data from raw bytes.
func (s *Session) decodeSessionData(rawData []byte) error {
mux.Lock()
defer mux.Unlock()
_, _ = s.byteBuffer.Write(rawData)
encCache := gob.NewDecoder(s.byteBuffer)
if err := encCache.Decode(&s.data.Data); err != nil {
return fmt.Errorf("failed to decode session data: %w", err)
}
return nil
}

0 comments on commit 83731ce

Please sign in to comment.