Skip to content

Commit

Permalink
Refactor approved users storage methods to use context for timeout ma…
Browse files Browse the repository at this point in the history
…nagement
  • Loading branch information
umputun committed Dec 29, 2024
1 parent 4d6a232 commit e1c6fae
Show file tree
Hide file tree
Showing 7 changed files with 205 additions and 97 deletions.
2 changes: 2 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,8 @@ data/spam-dynamic.txt
data/ham-dynamic.txt
data/approved-users.txt
tg-spam.db
tg-spam.db-shm
tg-spam.db-wal
logs/
site/public/
_examples/simplechat/messages.db
6 changes: 5 additions & 1 deletion app/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ type options struct {

HistoryDuration time.Duration `long:"history-duration" env:"HISTORY_DURATION" default:"24h" description:"history duration"`
HistoryMinSize int `long:"history-min-size" env:"HISTORY_MIN_SIZE" default:"1000" description:"history minimal size to keep"`
StorageTimeout time.Duration `long:"storage-timeout" env:"STORAGE_TIMEOUT" default:"0s" description:"storage timeout"`

Logger struct {
Enabled bool `long:"enabled" env:"ENABLED" description:"enable spam rotated logs"`
Expand Down Expand Up @@ -212,7 +213,7 @@ func execute(ctx context.Context, opts options) error {
log.Printf("[DEBUG] data db: %s", dataFile)

// make store and load approved users
approvedUsersStore, auErr := storage.NewApprovedUsers(dataDB)
approvedUsersStore, auErr := storage.NewApprovedUsers(ctx, dataDB)
if auErr != nil {
return fmt.Errorf("can't make approved users store, %w", auErr)
}
Expand Down Expand Up @@ -437,6 +438,9 @@ func makeDetector(opts options) *tgspam.Detector {
detectorConfig.FirstMessageOnly = false
detectorConfig.FirstMessagesCount = 0
}
if opts.StorageTimeout > 0 { // if StorageTimeout is non-zero, set it. If zero, storage timeout is disabled
detectorConfig.StorageTimeout = opts.StorageTimeout
}

detector := tgspam.NewDetector(detectorConfig)
log.Printf("[DEBUG] detector config: %+v", detectorConfig)
Expand Down
54 changes: 24 additions & 30 deletions app/storage/approved_users.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package storage

import (
"context"
"fmt"
"log"
"sync"
Expand All @@ -12,56 +13,54 @@ import (
"github.com/umputun/tg-spam/lib/approved"
)

// ApprovedUsers is a storage for approved users ids
// ApprovedUsers is a storage for approved users
type ApprovedUsers struct {
db *sqlx.DB
lock *sync.RWMutex
}

// ApprovedUsersInfo represents information about an approved user.
type approvedUsersInfo struct {
UserID string `db:"id"`
UserName string `db:"name"`
Timestamp time.Time `db:"timestamp"`
}

// NewApprovedUsers creates a new ApprovedUsers storage
func NewApprovedUsers(db *sqlx.DB) (*ApprovedUsers, error) {
func NewApprovedUsers(ctx context.Context, db *sqlx.DB) (*ApprovedUsers, error) {
if db == nil {
return nil, fmt.Errorf("db connection is nil")
}
if err := setSqlitePragma(db); err != nil {
return nil, fmt.Errorf("failed to set sqlite pragma: %w", err)
}

// migrate the table if necessary
err := migrateTable(db)
if err != nil {
if err := migrateTable(db); err != nil {
return nil, err
}

// create the table if it doesn't exist
_, err = db.Exec(`CREATE TABLE IF NOT EXISTS approved_users (
query := `CREATE TABLE IF NOT EXISTS approved_users (
id TEXT PRIMARY KEY,
name TEXT,
timestamp DATETIME DEFAULT CURRENT_TIMESTAMP
)`)
if err != nil {
)`
if _, err := db.ExecContext(ctx, query); err != nil {
return nil, fmt.Errorf("failed to create approved_users table: %w", err)
}

return &ApprovedUsers{db: db, lock: &sync.RWMutex{}}, nil
}

// Read returns all approved users.
func (au *ApprovedUsers) Read() ([]approved.UserInfo, error) {
// Read returns a list of all approved users
func (au *ApprovedUsers) Read(ctx context.Context) ([]approved.UserInfo, error) {
au.lock.RLock()
defer au.lock.RUnlock()

users := []approvedUsersInfo{}
err := au.db.Select(&users, "SELECT id, name, timestamp FROM approved_users ORDER BY timestamp DESC")
err := au.db.SelectContext(ctx, &users, "SELECT id, name, timestamp FROM approved_users ORDER BY timestamp DESC")
if err != nil {
return nil, fmt.Errorf("failed to get approved users: %w", err)
}

res := make([]approved.UserInfo, len(users))
for i, u := range users {
res[i] = approved.UserInfo{
Expand All @@ -74,35 +73,35 @@ func (au *ApprovedUsers) Read() ([]approved.UserInfo, error) {
return res, nil
}

// Write writes new user info to the storage
func (au *ApprovedUsers) Write(user approved.UserInfo) error {
// Write adds a new approved user
func (au *ApprovedUsers) Write(ctx context.Context, user approved.UserInfo) error {
au.lock.Lock()
defer au.lock.Unlock()

if user.Timestamp.IsZero() {
user.Timestamp = time.Now()
}
// Prepare the query to insert a new record, ignoring if it already exists

query := "INSERT OR IGNORE INTO approved_users (id, name, timestamp) VALUES (?, ?, ?)"
if _, err := au.db.Exec(query, user.UserID, user.UserName, user.Timestamp); err != nil {
if _, err := au.db.ExecContext(ctx, query, user.UserID, user.UserName, user.Timestamp); err != nil {
return fmt.Errorf("failed to insert user %+v: %w", user, err)
}
log.Printf("[INFO] user %s added to approved users", user.String())
return nil
}

// Delete deletes the given id from the storage
func (au *ApprovedUsers) Delete(id string) error {
// Delete removes an approved user by its ID
func (au *ApprovedUsers) Delete(ctx context.Context, id string) error {
au.lock.Lock()
defer au.lock.Unlock()
var user approvedUsersInfo

// retrieve the user's name and timestamp for logging purposes
err := au.db.Get(&user, "SELECT id, name, timestamp FROM approved_users WHERE id = ?", id)
var user approvedUsersInfo
err := au.db.GetContext(ctx, &user, "SELECT id, name, timestamp FROM approved_users WHERE id = ?", id)
if err != nil {
return fmt.Errorf("failed to get approved user for id %s: %w", id, err)
}

if _, err := au.db.Exec("DELETE FROM approved_users WHERE id = ?", id); err != nil {
if _, err := au.db.ExecContext(ctx, "DELETE FROM approved_users WHERE id = ?", id); err != nil {
return fmt.Errorf("failed to delete id %s: %w", id, err)
}

Expand All @@ -111,18 +110,16 @@ func (au *ApprovedUsers) Delete(id string) error {
}

func migrateTable(db *sqlx.DB) error {
// Check if the table exists
var exists int
err := db.Get(&exists, "SELECT COUNT(*) FROM sqlite_master WHERE type='table' AND name='approved_users'")
if err != nil {
return fmt.Errorf("failed to check for approved_users table existence: %w", err)
}

if exists == 0 {
return nil // Table does not exist, no need to migrate
return nil
}

// Get column info for 'id' column
var cols []struct {
CID int `db:"cid"`
Name string `db:"name"`
Expand All @@ -136,12 +133,9 @@ func migrateTable(db *sqlx.DB) error {
return fmt.Errorf("failed to get table info for approved_users: %w", err)
}

// Check if 'id' column is INTEGER
for _, col := range cols {
if col.Name == "id" && col.Type == "INTEGER" {
// Drop the table if 'id' is of type INTEGER
_, err = db.Exec("DROP TABLE approved_users")
if err != nil {
if _, err = db.Exec("DROP TABLE approved_users"); err != nil {
return fmt.Errorf("failed to drop old approved_users table: %w", err)
}
break
Expand Down
Loading

0 comments on commit e1c6fae

Please sign in to comment.