Skip to content

Commit

Permalink
refactor storage with the ability to set different engines with diffe…
Browse files Browse the repository at this point in the history
…rent locking support
  • Loading branch information
umputun committed Dec 30, 2024
1 parent 46f8dfb commit e9021ca
Show file tree
Hide file tree
Showing 14 changed files with 187 additions and 195 deletions.
9 changes: 4 additions & 5 deletions app/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@ import (
"github.com/go-pkgz/lgr"
"github.com/go-pkgz/rest"
"github.com/jessevdk/go-flags"
"github.com/jmoiron/sqlx"
"github.com/sashabaranov/go-openai"
"gopkg.in/natefinch/lumberjack.v2"

Expand Down Expand Up @@ -352,7 +351,7 @@ func checkVolumeMount(opts options) (ok bool) {
return false
}

func activateServer(ctx context.Context, opts options, sf *bot.SpamFilter, loc *storage.Locator, dataDB *sqlx.DB) (err error) {
func activateServer(ctx context.Context, opts options, sf *bot.SpamFilter, loc *storage.Locator, db *storage.Engine) (err error) {
authPassswd := opts.Server.AuthPasswd
if opts.Server.AuthPasswd == "auto" {
authPassswd, err = webapi.GenerateRandomPassword(20)
Expand All @@ -367,7 +366,7 @@ func activateServer(ctx context.Context, opts options, sf *bot.SpamFilter, loc *
}

// make store and load approved users
detectedSpamStore, auErr := storage.NewDetectedSpam(ctx, dataDB)
detectedSpamStore, auErr := storage.NewDetectedSpam(ctx, db)
if auErr != nil {
return fmt.Errorf("can't make approved users store, %w", auErr)
}
Expand Down Expand Up @@ -508,7 +507,7 @@ func makeDetector(opts options) *tgspam.Detector {
return detector
}

func makeSpamBot(ctx context.Context, opts options, dataDB *sqlx.DB, detector *tgspam.Detector) (*bot.SpamFilter, error) {
func makeSpamBot(ctx context.Context, opts options, dataDB *storage.Engine, detector *tgspam.Detector) (*bot.SpamFilter, error) {
if dataDB == nil || detector == nil {
return nil, errors.New("nil datadb or detector")
}
Expand Down Expand Up @@ -577,7 +576,7 @@ func (n nopWriteCloser) Close() error { return nil }

// makeSpamLogger creates spam logger to keep reports about spam messages
// it writes json lines to the provided writer
func makeSpamLogger(ctx context.Context, wr io.Writer, dataDB *sqlx.DB) (events.SpamLogger, error) {
func makeSpamLogger(ctx context.Context, wr io.Writer, dataDB *storage.Engine) (events.SpamLogger, error) {
// make store and load approved users
detectedSpamStore, auErr := storage.NewDetectedSpam(ctx, dataDB)
if auErr != nil {
Expand Down
19 changes: 9 additions & 10 deletions app/main_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@ import (
"testing"
"time"

"github.com/jmoiron/sqlx"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"

Expand All @@ -26,7 +25,7 @@ func TestMakeSpamLogger(t *testing.T) {
require.NoError(t, err)
defer os.Remove(file.Name())

db, err := sqlx.Open("sqlite", ":memory:")
db, err := storage.NewSqliteDB(":memory:")
require.NoError(t, err)
defer db.Close()

Expand Down Expand Up @@ -380,7 +379,7 @@ func Test_migrateSamples(t *testing.T) {
opts.Files.SamplesDataPath, opts.Files.DynamicDataPath = tmpDir, tmpDir

t.Run("full migration", func(t *testing.T) {
db, err := sqlx.Open("sqlite", ":memory:")
db, err := storage.NewSqliteDB(":memory:")
require.NoError(t, err)
defer db.Close()
store, err := storage.NewSamples(context.Background(), db)
Expand Down Expand Up @@ -428,7 +427,7 @@ func Test_migrateSamples(t *testing.T) {
})

t.Run("already migrated", func(t *testing.T) {
db, err := sqlx.Open("sqlite", ":memory:")
db, err := storage.NewSqliteDB(":memory:")
require.NoError(t, err)
defer db.Close()
store, err := storage.NewSamples(context.Background(), db)
Expand All @@ -452,7 +451,7 @@ func Test_migrateSamples(t *testing.T) {
})

t.Run("partial migration", func(t *testing.T) {
db, err := sqlx.Open("sqlite", ":memory:")
db, err := storage.NewSqliteDB(":memory:")
require.NoError(t, err)
defer db.Close()
store, err := storage.NewSamples(context.Background(), db)
Expand All @@ -478,7 +477,7 @@ func Test_migrateSamples(t *testing.T) {
})

t.Run("empty files", func(t *testing.T) {
db, err := sqlx.Open("sqlite", ":memory:")
db, err := storage.NewSqliteDB(":memory:")
require.NoError(t, err)
defer db.Close()
store, err := storage.NewSamples(context.Background(), db)
Expand All @@ -503,7 +502,7 @@ func Test_migrateDicts(t *testing.T) {
})

t.Run("full migration", func(t *testing.T) {
db, err := sqlx.Open("sqlite", ":memory:")
db, err := storage.NewSqliteDB(":memory:")
require.NoError(t, err)
defer db.Close()
dict, err := storage.NewDictionary(context.Background(), db)
Expand Down Expand Up @@ -537,7 +536,7 @@ func Test_migrateDicts(t *testing.T) {
})

t.Run("already migrated", func(t *testing.T) {
db, err := sqlx.Open("sqlite", ":memory:")
db, err := storage.NewSqliteDB(":memory:")
require.NoError(t, err)
defer db.Close()
dict, err := storage.NewDictionary(context.Background(), db)
Expand Down Expand Up @@ -571,7 +570,7 @@ func Test_migrateDicts(t *testing.T) {
})

t.Run("empty files", func(t *testing.T) {
db, err := sqlx.Open("sqlite", ":memory:")
db, err := storage.NewSqliteDB(":memory:")
require.NoError(t, err)
defer db.Close()
dict, err := storage.NewDictionary(context.Background(), db)
Expand All @@ -592,7 +591,7 @@ func Test_migrateDicts(t *testing.T) {
})

t.Run("partial migration", func(t *testing.T) {
db, err := sqlx.Open("sqlite", ":memory:")
db, err := storage.NewSqliteDB(":memory:")
require.NoError(t, err)
defer db.Close()
dict, err := storage.NewDictionary(context.Background(), db)
Expand Down
25 changes: 10 additions & 15 deletions app/storage/approved_users.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@ import (
"context"
"fmt"
"log"
"sync"
"time"

"github.com/jmoiron/sqlx"
Expand All @@ -15,8 +14,7 @@ import (

// ApprovedUsers is a storage for approved users
type ApprovedUsers struct {
db *sqlx.DB
lock *sync.RWMutex
db *Engine
}

type approvedUsersInfo struct {
Expand All @@ -26,15 +24,12 @@ type approvedUsersInfo struct {
}

// NewApprovedUsers creates a new ApprovedUsers storage
func NewApprovedUsers(ctx context.Context, db *sqlx.DB) (*ApprovedUsers, error) {
func NewApprovedUsers(ctx context.Context, db *Engine) (*ApprovedUsers, error) {
if db == nil {
return nil, fmt.Errorf("db connection is nil")
}
if err := setSqlitePragma(db); err != nil {
return nil, fmt.Errorf("failed to set sqlite pragma: %w", err)
}

if err := migrateTable(db); err != nil {
if err := migrateTable(&db.DB); err != nil {
return nil, err
}

Expand All @@ -47,13 +42,13 @@ func NewApprovedUsers(ctx context.Context, db *sqlx.DB) (*ApprovedUsers, error)
return nil, fmt.Errorf("failed to create approved_users table: %w", err)
}

return &ApprovedUsers{db: db, lock: &sync.RWMutex{}}, nil
return &ApprovedUsers{db: db}, nil
}

// Read returns a list of all approved users
func (au *ApprovedUsers) Read(ctx context.Context) ([]approved.UserInfo, error) {
au.lock.RLock()
defer au.lock.RUnlock()
au.db.RLock()
defer au.db.RUnlock()

users := []approvedUsersInfo{}
err := au.db.SelectContext(ctx, &users, "SELECT id, name, timestamp FROM approved_users ORDER BY timestamp DESC")
Expand All @@ -75,8 +70,8 @@ func (au *ApprovedUsers) Read(ctx context.Context) ([]approved.UserInfo, error)

// Write adds a new approved user
func (au *ApprovedUsers) Write(ctx context.Context, user approved.UserInfo) error {
au.lock.Lock()
defer au.lock.Unlock()
au.db.Lock()
defer au.db.Unlock()

if user.Timestamp.IsZero() {
user.Timestamp = time.Now()
Expand All @@ -92,8 +87,8 @@ func (au *ApprovedUsers) Write(ctx context.Context, user approved.UserInfo) erro

// Delete removes an approved user by its ID
func (au *ApprovedUsers) Delete(ctx context.Context, id string) error {
au.lock.Lock()
defer au.lock.Unlock()
au.db.Lock()
defer au.db.Unlock()

var user approvedUsersInfo
err := au.db.GetContext(ctx, &user, "SELECT id, name, timestamp FROM approved_users WHERE id = ?", id)
Expand Down
28 changes: 12 additions & 16 deletions app/storage/approved_users_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@ import (
"testing"
"time"

"github.com/jmoiron/sqlx"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"

Expand All @@ -15,11 +14,10 @@ import (

func TestApprovedUsers_NewApprovedUsers(t *testing.T) {
t.Run("create new table", func(t *testing.T) {
db, err := sqlx.Open("sqlite", ":memory:")
require.NoError(t, err)
defer db.Close()
db, teardown := setupTestDB(t)
defer teardown()

_, err = NewApprovedUsers(context.Background(), db)
_, err := NewApprovedUsers(context.Background(), db)
require.NoError(t, err)

// check if the table and columns exist
Expand All @@ -34,12 +32,11 @@ func TestApprovedUsers_NewApprovedUsers(t *testing.T) {
})

t.Run("table already exists", func(t *testing.T) {
db, err := sqlx.Open("sqlite", ":memory:")
require.NoError(t, err)
defer db.Close()
db, teardown := setupTestDB(t)
defer teardown()

// Create table with 'name' column
_, err = db.Exec(`CREATE TABLE approved_users (id TEXT PRIMARY KEY, name TEXT, timestamp DATETIME DEFAULT CURRENT_TIMESTAMP)`)
_, err := db.Exec(`CREATE TABLE approved_users (id TEXT PRIMARY KEY, name TEXT, timestamp DATETIME DEFAULT CURRENT_TIMESTAMP)`)
require.NoError(t, err)

_, err = NewApprovedUsers(context.Background(), db)
Expand All @@ -53,12 +50,11 @@ func TestApprovedUsers_NewApprovedUsers(t *testing.T) {
})

t.Run("table exists with INTEGER id", func(t *testing.T) {
db, err := sqlx.Open("sqlite", ":memory:")
require.NoError(t, err)
defer db.Close()
db, teardown := setupTestDB(t)
defer teardown()

// Create table with INTEGER id
_, err = db.Exec(`CREATE TABLE approved_users (id INTEGER PRIMARY KEY, name TEXT, timestamp DATETIME DEFAULT CURRENT_TIMESTAMP)`)
_, err := db.Exec(`CREATE TABLE approved_users (id INTEGER PRIMARY KEY, name TEXT, timestamp DATETIME DEFAULT CURRENT_TIMESTAMP)`)
require.NoError(t, err)

_, err = NewApprovedUsers(context.Background(), db)
Expand All @@ -73,9 +69,9 @@ func TestApprovedUsers_NewApprovedUsers(t *testing.T) {
}

func TestApprovedUsers_Write(t *testing.T) {
db, e := sqlx.Open("sqlite", ":memory:")
require.NoError(t, e)
defer db.Close()
db, teardown := setupTestDB(t)
defer teardown()

ctx := context.Background()
au, e := NewApprovedUsers(ctx, db)
require.NoError(t, e)
Expand Down
25 changes: 9 additions & 16 deletions app/storage/detected_spam.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,20 +6,16 @@ import (
"fmt"
"log"
"strings"
"sync"
"time"

"github.com/jmoiron/sqlx"

"github.com/umputun/tg-spam/lib/spamcheck"
)

const maxDetectedSpamEntries = 500

// DetectedSpam is a storage for detected spam entries
type DetectedSpam struct {
db *sqlx.DB
lock *sync.RWMutex
db *Engine
}

// DetectedSpamInfo represents information about a detected spam entry.
Expand All @@ -35,13 +31,10 @@ type DetectedSpamInfo struct {
}

// NewDetectedSpam creates a new DetectedSpam storage
func NewDetectedSpam(ctx context.Context, db *sqlx.DB) (*DetectedSpam, error) {
func NewDetectedSpam(ctx context.Context, db *Engine) (*DetectedSpam, error) {
if db == nil {
return nil, fmt.Errorf("db connection is nil")
}
if err := setSqlitePragma(db); err != nil {
return nil, fmt.Errorf("failed to set sqlite pragma: %w", err)
}

_, err := db.Exec(`CREATE TABLE IF NOT EXISTS detected_spam (
id INTEGER PRIMARY KEY AUTOINCREMENT,
Expand All @@ -67,13 +60,13 @@ func NewDetectedSpam(ctx context.Context, db *sqlx.DB) (*DetectedSpam, error) {
return nil, fmt.Errorf("failed to create index on timestamp: %w", err)
}

return &DetectedSpam{db: db, lock: &sync.RWMutex{}}, nil
return &DetectedSpam{db: db}, nil
}

// Write adds a new detected spam entry
func (ds *DetectedSpam) Write(ctx context.Context, entry DetectedSpamInfo, checks []spamcheck.Response) error {
ds.lock.Lock()
defer ds.lock.Unlock()
ds.db.Lock()
defer ds.db.Unlock()

checksJSON, err := json.Marshal(checks)
if err != nil {
Expand All @@ -91,8 +84,8 @@ func (ds *DetectedSpam) Write(ctx context.Context, entry DetectedSpamInfo, check

// SetAddedToSamplesFlag sets the added flag to true for the detected spam entry with the given id
func (ds *DetectedSpam) SetAddedToSamplesFlag(ctx context.Context, id int64) error {
ds.lock.Lock()
defer ds.lock.Unlock()
ds.db.Lock()
defer ds.db.Unlock()

query := `UPDATE detected_spam SET added = 1 WHERE id = ?`
if _, err := ds.db.ExecContext(ctx, query, id); err != nil {
Expand All @@ -103,8 +96,8 @@ func (ds *DetectedSpam) SetAddedToSamplesFlag(ctx context.Context, id int64) err

// Read returns all detected spam entries
func (ds *DetectedSpam) Read(ctx context.Context) ([]DetectedSpamInfo, error) {
ds.lock.RLock()
defer ds.lock.RUnlock()
ds.db.RLock()
defer ds.db.RUnlock()

var entries []DetectedSpamInfo
err := ds.db.SelectContext(ctx, &entries, "SELECT * FROM detected_spam ORDER BY timestamp DESC LIMIT ?", maxDetectedSpamEntries)
Expand Down
Loading

0 comments on commit e9021ca

Please sign in to comment.