Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix(middleware/session): mutex for thread safety #3049

Merged
merged 3 commits into from
Jun 29, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 32 additions & 8 deletions middleware/session/session.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
)

type Session struct {
mu sync.RWMutex // Mutex to protect non-data fields
id string // session id
fresh bool // if new session
ctx fiber.Ctx // fiber context
Expand Down Expand Up @@ -56,11 +57,15 @@

// Fresh is true if the current session is new
func (s *Session) Fresh() bool {
s.mu.RLock()
defer s.mu.RUnlock()
return s.fresh
}

// ID returns the session id
func (s *Session) ID() string {
s.mu.RLock()
defer s.mu.RUnlock()
return s.id
}

Expand Down Expand Up @@ -101,6 +106,9 @@
// Reset local data
s.data.Reset()

s.mu.Lock()
defer s.mu.Unlock()

// Use external Storage if exist
if err := s.config.Storage.Delete(s.id); err != nil {
return err
Expand All @@ -113,6 +121,9 @@

// Regenerate generates a new session id and delete the old one from Storage
func (s *Session) Regenerate() error {
s.mu.Lock()
defer s.mu.Unlock()

// Delete old id from storage
if err := s.config.Storage.Delete(s.id); err != nil {
return err
Expand All @@ -137,6 +148,9 @@
// Reset expiration
s.exp = 0

s.mu.Lock()
defer s.mu.Unlock()

// Delete old id from storage
if err := s.config.Storage.Delete(s.id); err != nil {
return err
Expand All @@ -153,10 +167,7 @@

// refresh generates a new session, and set session.fresh to be true
func (s *Session) refresh() {
// Create a new id
s.id = s.config.KeyGenerator()

// We assign a new id to the session, so the session must be fresh
s.fresh = true
}

Expand All @@ -167,6 +178,9 @@
return nil
}

s.mu.Lock()
defer s.mu.Unlock()

// Check if session has your own expiration, otherwise use default value
if s.exp <= 0 {
s.exp = s.config.Expiration
Expand All @@ -176,25 +190,23 @@
s.setSession()

// Convert data to bytes
mux.Lock()
defer mux.Unlock()
encCache := gob.NewEncoder(s.byteBuffer)
err := encCache.Encode(&s.data.Data)
if err != nil {
return fmt.Errorf("failed to encode data: %w", err)
}

// copy the data in buffer
// Copy the data in buffer
encodedBytes := make([]byte, s.byteBuffer.Len())
copy(encodedBytes, s.byteBuffer.Bytes())

// pass copied bytes with session id to provider
// Pass copied bytes with session id to provider
if err := s.config.Storage.Set(s.id, encodedBytes, s.exp); err != nil {
return err
}

// Release session
// TODO: It's not safe to use the Session after called Save()
// TODO: It's not safe to use the Session after calling Save()
releaseSession(s)

return nil
Expand All @@ -210,6 +222,8 @@

// SetExpiry sets a specific expiration for this session
func (s *Session) SetExpiry(exp time.Duration) {
s.mu.Lock()
defer s.mu.Unlock()
s.exp = exp
}

Expand Down Expand Up @@ -275,3 +289,13 @@
fasthttp.ReleaseCookie(fcookie)
}
}

// decodeSessionData decodes the session data from raw bytes.
func (s *Session) decodeSessionData(rawData []byte) error {
_, _ = s.byteBuffer.Write(rawData)
encCache := gob.NewDecoder(s.byteBuffer)
if err := encCache.Decode(&s.data.Data); err != nil {
sixcolors marked this conversation as resolved.
Show resolved Hide resolved
return fmt.Errorf("failed to decode session data: %w", err)

Check warning on line 298 in middleware/session/session.go

View check run for this annotation

Codecov / codecov/patch

middleware/session/session.go#L298

Added line #L298 was not covered by tests
}
return nil
}
107 changes: 107 additions & 0 deletions middleware/session/session_test.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
package session

import (
"errors"
"sync"
"testing"
"time"

Expand Down Expand Up @@ -856,3 +858,108 @@ func Benchmark_Session_Asserted_Parallel(b *testing.B) {
})
})
}

// go test -v -race -run Test_Session_Concurrency ./...
func Test_Session_Concurrency(t *testing.T) {
t.Parallel()
app := fiber.New()
store := New()

var wg sync.WaitGroup
errChan := make(chan error, 10) // Buffered channel to collect errors
const numGoroutines = 10 // Number of concurrent goroutines to test

// Start numGoroutines goroutines
for i := 0; i < numGoroutines; i++ {
wg.Add(1)
go func() {
defer wg.Done()

localCtx := app.AcquireCtx(&fasthttp.RequestCtx{})

sess, err := store.Get(localCtx)
if err != nil {
errChan <- err
return
}

// Set a value
sess.Set("name", "john")

// get the session id
id := sess.ID()

// Check if the session is fresh
if !sess.Fresh() {
errChan <- errors.New("session should be fresh")
return
}

// Save the session
if err := sess.Save(); err != nil {
errChan <- err
return
}

// Release the context
app.ReleaseCtx(localCtx)

// Acquire a new context
localCtx = app.AcquireCtx(&fasthttp.RequestCtx{})
defer app.ReleaseCtx(localCtx)

// Set the session id in the header
localCtx.Request().Header.SetCookie(store.sessionName, id)

// Get the session
sess, err = store.Get(localCtx)
if err != nil {
errChan <- err
return
}

// Get the value
name := sess.Get("name")
if name != "john" {
errChan <- errors.New("name should be john")
return
}

// Get ID from the session
if sess.ID() != id {
errChan <- errors.New("id should be the same")
return
}

// Check if the session is fresh
if sess.Fresh() {
errChan <- errors.New("session should not be fresh")
return
}

// Delete the key
sess.Delete("name")

// Get the value
name = sess.Get("name")
if name != nil {
errChan <- errors.New("name should be nil")
return
}

// Destroy the session
if err := sess.Destroy(); err != nil {
errChan <- err
return
}
}()
}

wg.Wait() // Wait for all goroutines to finish
close(errChan) // Close the channel to signal no more errors will be sent

// Check for errors sent to errChan
for err := range errChan {
require.NoError(t, err)
}
}
16 changes: 0 additions & 16 deletions middleware/session/store.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@ import (
"encoding/gob"
"errors"
"fmt"
"sync"

"github.com/gofiber/fiber/v3"
"github.com/gofiber/fiber/v3/internal/storage/memory"
Expand All @@ -14,9 +13,6 @@ import (
// ErrEmptySessionID is an error that occurs when the session ID is empty.
var ErrEmptySessionID = errors.New("session id cannot be empty")

// mux is a global mutex for session operations.
var mux sync.Mutex

// sessionIDKey is the local key type used to store and retrieve the session ID in context.
type sessionIDKey int

Expand Down Expand Up @@ -132,15 +128,3 @@ func (s *Store) Delete(id string) error {
}
return s.Storage.Delete(id)
}

// decodeSessionData decodes the session data from raw bytes.
func (s *Session) decodeSessionData(rawData []byte) error {
mux.Lock()
defer mux.Unlock()
_, _ = s.byteBuffer.Write(rawData)
encCache := gob.NewDecoder(s.byteBuffer)
if err := encCache.Decode(&s.data.Data); err != nil {
return fmt.Errorf("failed to decode session data: %w", err)
}
return nil
}
Loading