Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

sessions/mysql: Del expired sessions on startup. #1669

Merged
merged 2 commits into from
Aug 5, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions politeiawww/sessions/doc.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,5 +29,12 @@ MaxAge of <= 0.
The key used to encode/decode the session ID and the session values is provided
to the session store on initialization. Keys can be rotated by providing
multiple keys on initialization.
The session store does not delete expired sessions from the database. The
gorilla/sessions API does not allow the session ID to be retrieved from a
session cookie once the session has expired, so there is no way for the session
store to know what IDs needs to be deleted from the database. The database
layer must track when the session was created and manually delete expired
sessions.
*/
package sessions
2 changes: 1 addition & 1 deletion politeiawww/sessions/log.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,5 +29,5 @@ func UseLogger(logger slog.Logger) {

// Initialize the package logger.
func init() {
UseLogger(logger.NewSubsystem("SESS"))
UseLogger(logger.NewSubsystem("SESN"))
}
2 changes: 1 addition & 1 deletion politeiawww/sessions/mysql/log.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,5 +29,5 @@ func UseLogger(logger slog.Logger) {

// Initialize the package logger.
func init() {
UseLogger(logger.NewSubsystem("SESS"))
UseLogger(logger.NewSubsystem("SNDB"))
}
227 changes: 148 additions & 79 deletions politeiawww/sessions/mysql/mysql.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,36 +15,23 @@ import (
"github.com/pkg/errors"
)

const (
// defaultTableName is the default table name for the sessions table.
defaultTableName = "sessions"

// defaultOpTimeout is the default timeout for a single database operation.
defaultOpTimeout = 1 * time.Minute
)

// tableSessions defines the sessions table.
// sessionsTable is the table for the encoded session values.
//
// The id column is 128 bytes so that it can accomidate a 64 byte base64,
// base32, or hex encoded key.
//
// id column is 128 bytes so that it can accomidate a 64 byte base64, base32,
// or hex encoded key.
// The encoded_session column has a max length of 2^16 bytes, which is around
// 64KB.
//
// encoded_session column max length is up to 2^16 bytes which is around 64KB.
const tableSessions = `
id CHAR(128) NOT NULL PRIMARY KEY,
encoded_session BLOB NOT NULL
// The created_at column contains a Unix timestamp and is used to manually
// clean up expired sessions. The gorilla/sessions Store does not do this
// automatically.
const sessionsTable = `
id CHAR(128) PRIMARY KEY,
encoded_session BLOB NOT NULL,
created_at BIGINT NOT NULL
`

// Opts includes configurable options for the sessions database.
type Opts struct {
// TableName is the table name for the sessions table. Defaults to
// "sessions".
TableName string

// OpTimeout is the timeout for a single database operation. Defaults to
// 1 minute.
OpTimeout time.Duration
}

var (
_ sessions.DB = (*mysql)(nil)
)
Expand All @@ -54,24 +41,85 @@ type mysql struct {
// db is the mysql DB context.
db *sql.DB

// opts includes the sessions database options.
// sessionMaxAge is the max age of a session in seconds. This is used to
// periodically clean up expired sessions from the database. The
// gorilla/sessions Store implemenation does not do this automatically. It
// must be done manually in the database layer.
sessionMaxAge int64

// opts contains the session database options.
opts *Opts
}

// ctxForOp returns a context and cancel function for a single database
// operation. It uses the database operation timeout set on the mysql
// context.
func (m *mysql) ctxForOp() (context.Context, func()) {
return context.WithTimeout(context.Background(), m.opts.OpTimeout)
// Opts contains configurable options for the sessions database. These are
// not required. Sane defaults are used when the options are not provided.
type Opts struct {
// TableName is the table name for the sessions table.
TableName string

// OpTimeout is the timeout for a single database operation.
OpTimeout time.Duration
}

const (
// defaultTableName is the default table name for the sessions table.
defaultTableName = "sessions"

// defaultOpTimeout is the default timeout for a single database operation.
defaultOpTimeout = 1 * time.Minute
)

// New returns a new mysql context that implements the sessions DB interface.
// The opts param can be used to override the default mysql context settings.
//
// The sessionMaxAge is the max age in seconds of a session. This function
// cleans up any expired sessions from the database as part of the
// initialization. A sessionMaxAge of <=0 will cause the sessions database
// to be dropped and recreated.
func New(db *sql.DB, sessionMaxAge int64, opts *Opts) (*mysql, error) {
// Setup the database options
if opts == nil {
opts = &Opts{}
}
if opts.TableName == "" {
opts.TableName = defaultTableName
}
if opts.OpTimeout == 0 {
opts.OpTimeout = defaultOpTimeout
}

// Setup the mysql context
m := mysql{
db: db,
sessionMaxAge: sessionMaxAge,
opts: opts,
}

// Perform database setup
if sessionMaxAge <= 0 {
err := m.dropTable()
if err != nil {
return nil, err
}
}
err := m.createTable()
if err != nil {
return nil, err
}
err = m.cleanup()
if err != nil {
return nil, err
}

return &m, nil
}

// Save saves a session to the database.
//
// Save satisfies the sessions.DB interface.
func (m *mysql) Save(sessionID string, s sessions.EncodedSession) error {
log.Tracef("Save: %v", sessionID)
log.Tracef("Save %v", sessionID)

// Marshal encoded session
es, err := json.Marshal(s)
if err != nil {
return err
Expand All @@ -80,12 +128,13 @@ func (m *mysql) Save(sessionID string, s sessions.EncodedSession) error {
ctx, cancel := m.ctxForOp()
defer cancel()

// Save session to database
q := fmt.Sprintf(`INSERT INTO %v
(id, encoded_session) VALUES (?, ?)
ON DUPLICATE KEY UPDATE
encoded_session = VALUES(encoded_session)`, m.opts.TableName)
_, err = m.db.ExecContext(ctx, q, sessionID, es)
q := `INSERT INTO %v
(id, encoded_session, created_at) VALUES (?, ?, ?)
ON DUPLICATE KEY UPDATE
encoded_session = VALUES(encoded_session)`

q = fmt.Sprintf(q, m.opts.TableName)
_, err = m.db.ExecContext(ctx, q, sessionID, es, time.Now().Unix())
if err != nil {
return errors.WithStack(err)
}
Expand All @@ -98,16 +147,15 @@ func (m *mysql) Save(sessionID string, s sessions.EncodedSession) error {
//
// Del satisfies the sessions.DB interface.
func (m *mysql) Del(sessionID string) error {
log.Tracef("Del: %v", sessionID)
log.Tracef("Del %v", sessionID)

ctx, cancel := m.ctxForOp()
defer cancel()

// Delete session
_, err := m.db.ExecContext(ctx,
"DELETE FROM "+m.opts.TableName+" WHERE id = ?", sessionID)
q := fmt.Sprintf("DELETE FROM %v WHERE id = ?", m.opts.TableName)
_, err := m.db.ExecContext(ctx, q, sessionID)
if err != nil {
return err
return errors.WithStack(err)
}

return nil
Expand All @@ -118,24 +166,23 @@ func (m *mysql) Del(sessionID string) error {
//
// Get statisfies the sessions.DB interface.
func (m *mysql) Get(sessionID string) (*sessions.EncodedSession, error) {
log.Tracef("Get: %v", sessionID)
log.Tracef("Get %v", sessionID)

ctx, cancel := m.ctxForOp()
defer cancel()

// Get session
q := fmt.Sprintf("SELECT encoded_session FROM %v WHERE id = ?",
m.opts.TableName)

var encodedBlob []byte
err := m.db.QueryRowContext(ctx,
"SELECT encoded_session FROM "+m.opts.TableName+" WHERE id = ?",
sessionID).Scan(&encodedBlob)
err := m.db.QueryRowContext(ctx, q, sessionID).Scan(&encodedBlob)
switch {
case err == sql.ErrNoRows:
return nil, sessions.ErrNotFound
case err != nil:
return nil, err
return nil, errors.WithStack(err)
}

// Decode blob
var es sessions.EncodedSession
err = json.Unmarshal(encodedBlob, &es)
if err != nil {
Expand All @@ -145,41 +192,63 @@ func (m *mysql) Get(sessionID string) (*sessions.EncodedSession, error) {
return &es, nil
}

// New returns a new mysql context that implements the sessions DB interface.
// The opts param can be used to override the default mysql context settings.
func New(db *sql.DB, opts *Opts) (*mysql, error) {
// Setup database options.
tableName := defaultTableName
opTimeout := defaultOpTimeout
// Override defaults if options are provided
if opts != nil {
if opts.TableName != "" {
tableName = opts.TableName
}
if opts.OpTimeout != 0 {
opTimeout = opts.OpTimeout
}
// createTable creates the sessions table.
func (m *mysql) createTable() error {
ctx, cancel := m.ctxForOp()
defer cancel()

q := fmt.Sprintf("CREATE TABLE IF NOT EXISTS %v (%v)",
m.opts.TableName, sessionsTable)
_, err := m.db.ExecContext(ctx, q)
if err != nil {
return errors.WithStack(err)
}

// Create mysql context
m := mysql{
db: db,
opts: &Opts{
TableName: tableName,
OpTimeout: opTimeout,
},
log.Debugf("Created %v database table", m.opts.TableName)

return nil
}

// dropTable drops the sessions table.
func (m *mysql) dropTable() error {
ctx, cancel := m.ctxForOp()
defer cancel()

q := fmt.Sprintf("DROP TABLE IF EXISTS %v", m.opts.TableName)
_, err := m.db.ExecContext(ctx, q)
if err != nil {
return errors.WithStack(err)
}

log.Debugf("Dropped %v database table", m.opts.TableName)

return nil
}

// cleanup performs database cleanup by deleting all sessions that have
// expired.
func (m *mysql) cleanup() error {
ctx, cancel := m.ctxForOp()
defer cancel()

// Create sessions table
q := fmt.Sprintf(`CREATE TABLE IF NOT EXISTS %v (%v)`,
m.opts.TableName, tableSessions)
_, err := db.ExecContext(ctx, q)
q := "DELETE FROM %v WHERE created_at + ? <= ?"
q = fmt.Sprintf(q, m.opts.TableName)
r, err := m.db.ExecContext(ctx, q, m.sessionMaxAge, time.Now().Unix())
if err != nil {
return nil, errors.WithStack(err)
return errors.WithStack(err)
}
rowsAffected, err := r.RowsAffected()
if err != nil {
return err
}

return &m, nil
log.Debugf("Deleted %v expired sessions from the database", rowsAffected)

return nil
}

// ctxForOp returns a context and cancel function for a single database
// operation.
func (m *mysql) ctxForOp() (context.Context, func()) {
return context.WithTimeout(context.Background(), m.opts.OpTimeout)
}
Loading