Skip to content

Commit

Permalink
add ctx support for detected spam store
Browse files Browse the repository at this point in the history
  • Loading branch information
umputun committed Dec 29, 2024
1 parent e1c6fae commit 1281b49
Show file tree
Hide file tree
Showing 7 changed files with 69 additions and 48 deletions.
10 changes: 5 additions & 5 deletions app/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -265,7 +265,7 @@ func execute(ctx context.Context, opts options) error {
defer loggerWr.Close()

// make spam logger
spamLogger, err := makeSpamLogger(loggerWr, dataDB)
spamLogger, err := makeSpamLogger(ctx, loggerWr, dataDB)
if err != nil {
return fmt.Errorf("can't make spam logger, %w", err)
}
Expand Down Expand Up @@ -358,7 +358,7 @@ func activateServer(ctx context.Context, opts options, sf *bot.SpamFilter, loc *
}

// make store and load approved users
detectedSpamStore, auErr := storage.NewDetectedSpam(dataDB)
detectedSpamStore, auErr := storage.NewDetectedSpam(ctx, dataDB)
if auErr != nil {
return fmt.Errorf("can't make approved users store, %w", auErr)
}
Expand Down Expand Up @@ -546,9 +546,9 @@ func (n nopWriteCloser) Close() error { return nil }

// makeSpamLogger creates spam logger to keep reports about spam messages
// it writes json lines to the provided writer
func makeSpamLogger(wr io.Writer, dataDB *sqlx.DB) (events.SpamLogger, error) {
func makeSpamLogger(ctx context.Context, wr io.Writer, dataDB *sqlx.DB) (events.SpamLogger, error) {
// make store and load approved users
detectedSpamStore, auErr := storage.NewDetectedSpam(dataDB)
detectedSpamStore, auErr := storage.NewDetectedSpam(ctx, dataDB)
if auErr != nil {
return nil, fmt.Errorf("can't make approved users store, %w", auErr)
}
Expand Down Expand Up @@ -587,7 +587,7 @@ func makeSpamLogger(wr io.Writer, dataDB *sqlx.DB) (events.SpamLogger, error) {
UserName: msg.From.Username,
Timestamp: time.Now().In(time.Local),
}
if err := detectedSpamStore.Write(rec, response.CheckResults); err != nil {
if err := detectedSpamStore.Write(ctx, rec, response.CheckResults); err != nil {
log.Printf("[WARN] can't write to db, %v", err)
}
})
Expand Down
2 changes: 1 addition & 1 deletion app/main_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ func TestMakeSpamLogger(t *testing.T) {
require.NoError(t, err)
defer db.Close()

logger, err := makeSpamLogger(file, db)
logger, err := makeSpamLogger(context.Background(), file, db)
require.NoError(t, err)

msg := &bot.Message{
Expand Down
19 changes: 10 additions & 9 deletions app/storage/detected_spam.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package storage

import (
"context"
"encoding/json"
"fmt"
"log"
Expand Down Expand Up @@ -34,7 +35,7 @@ type DetectedSpamInfo struct {
}

// NewDetectedSpam creates a new DetectedSpam storage
func NewDetectedSpam(db *sqlx.DB) (*DetectedSpam, error) {
func NewDetectedSpam(ctx context.Context, db *sqlx.DB) (*DetectedSpam, error) {
if db == nil {
return nil, fmt.Errorf("db connection is nil")
}
Expand All @@ -55,22 +56,22 @@ func NewDetectedSpam(db *sqlx.DB) (*DetectedSpam, error) {
return nil, fmt.Errorf("failed to create detected_spam table: %w", err)
}

_, err = db.Exec(`ALTER TABLE detected_spam ADD COLUMN added BOOLEAN DEFAULT 0`)
_, err = db.ExecContext(ctx, `ALTER TABLE detected_spam ADD COLUMN added BOOLEAN DEFAULT 0`)
if err != nil {
if !strings.Contains(err.Error(), "duplicate column name") {
return nil, fmt.Errorf("failed to alter detected_spam table: %w", err)
}
}
// add index on timestamp
if _, err = db.Exec(`CREATE INDEX IF NOT EXISTS idx_detected_spam_timestamp ON detected_spam(timestamp)`); err != nil {
if _, err = db.ExecContext(ctx, `CREATE INDEX IF NOT EXISTS idx_detected_spam_timestamp ON detected_spam(timestamp)`); err != nil {
return nil, fmt.Errorf("failed to create index on timestamp: %w", err)
}

return &DetectedSpam{db: db, lock: &sync.RWMutex{}}, nil
}

// Write adds a new detected spam entry
func (ds *DetectedSpam) Write(entry DetectedSpamInfo, checks []spamcheck.Response) error {
func (ds *DetectedSpam) Write(ctx context.Context, entry DetectedSpamInfo, checks []spamcheck.Response) error {
ds.lock.Lock()
defer ds.lock.Unlock()

Expand All @@ -80,7 +81,7 @@ func (ds *DetectedSpam) Write(entry DetectedSpamInfo, checks []spamcheck.Respons
}

query := `INSERT INTO detected_spam (text, user_id, user_name, timestamp, checks) VALUES (?, ?, ?, ?, ?)`
if _, err := ds.db.Exec(query, entry.Text, entry.UserID, entry.UserName, entry.Timestamp, checksJSON); err != nil {
if _, err := ds.db.ExecContext(ctx, query, entry.Text, entry.UserID, entry.UserName, entry.Timestamp, checksJSON); err != nil {
return fmt.Errorf("failed to insert detected spam entry: %w", err)
}

Expand All @@ -89,24 +90,24 @@ func (ds *DetectedSpam) Write(entry DetectedSpamInfo, checks []spamcheck.Respons
}

// SetAddedToSamplesFlag sets the added flag to true for the detected spam entry with the given id
func (ds *DetectedSpam) SetAddedToSamplesFlag(id int64) error {
func (ds *DetectedSpam) SetAddedToSamplesFlag(ctx context.Context, id int64) error {
ds.lock.Lock()
defer ds.lock.Unlock()

query := `UPDATE detected_spam SET added = 1 WHERE id = ?`
if _, err := ds.db.Exec(query, id); err != nil {
if _, err := ds.db.ExecContext(ctx, query, id); err != nil {
return fmt.Errorf("failed to update added to samples flag: %w", err)
}
return nil
}

// Read returns all detected spam entries
func (ds *DetectedSpam) Read() ([]DetectedSpamInfo, error) {
func (ds *DetectedSpam) Read(ctx context.Context) ([]DetectedSpamInfo, error) {
ds.lock.RLock()
defer ds.lock.RUnlock()

var entries []DetectedSpamInfo
err := ds.db.Select(&entries, "SELECT * FROM detected_spam ORDER BY timestamp DESC LIMIT ?", maxDetectedSpamEntries)
err := ds.db.SelectContext(ctx, &entries, "SELECT * FROM detected_spam ORDER BY timestamp DESC LIMIT ?", maxDetectedSpamEntries)
if err != nil {
return nil, fmt.Errorf("failed to get detected spam entries: %w", err)
}
Expand Down
32 changes: 19 additions & 13 deletions app/storage/detected_spam_test.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package storage

import (
"context"
"encoding/json"
"testing"
"time"
Expand All @@ -17,7 +18,7 @@ func TestDetectedSpam_NewDetectedSpam(t *testing.T) {
require.NoError(t, err)
defer db.Close()

_, err = NewDetectedSpam(db)
_, err = NewDetectedSpam(context.Background(), db)
require.NoError(t, err)

var exists int
Expand All @@ -30,8 +31,9 @@ func TestDetectedSpam_Write(t *testing.T) {
db, err := sqlx.Open("sqlite", ":memory:")
require.NoError(t, err)
defer db.Close()
ctx := context.Background()

ds, err := NewDetectedSpam(db)
ds, err := NewDetectedSpam(ctx, db)
require.NoError(t, err)

spamEntry := DetectedSpamInfo{
Expand All @@ -49,7 +51,7 @@ func TestDetectedSpam_Write(t *testing.T) {
},
}

err = ds.Write(spamEntry, checks)
err = ds.Write(ctx, spamEntry, checks)
require.NoError(t, err)

var count int
Expand All @@ -62,8 +64,9 @@ func TestSetAddedToSamplesFlag(t *testing.T) {
db, err := sqlx.Open("sqlite", ":memory:")
require.NoError(t, err)
defer db.Close()
ctx := context.Background()

ds, err := NewDetectedSpam(db)
ds, err := NewDetectedSpam(ctx, db)
require.NoError(t, err)

spamEntry := DetectedSpamInfo{
Expand All @@ -81,14 +84,14 @@ func TestSetAddedToSamplesFlag(t *testing.T) {
},
}

err = ds.Write(spamEntry, checks)
err = ds.Write(ctx, spamEntry, checks)
require.NoError(t, err)
var added bool
err = db.Get(&added, "SELECT added FROM detected_spam WHERE text = ?", spamEntry.Text)
require.NoError(t, err)
assert.False(t, added)

err = ds.SetAddedToSamplesFlag(1)
err = ds.SetAddedToSamplesFlag(ctx, 1)
require.NoError(t, err)

err = db.Get(&added, "SELECT added FROM detected_spam WHERE text = ?", spamEntry.Text)
Expand All @@ -100,8 +103,9 @@ func TestDetectedSpam_Read(t *testing.T) {
db, err := sqlx.Open("sqlite", ":memory:")
require.NoError(t, err)
defer db.Close()
ctx := context.Background()

ds, err := NewDetectedSpam(db)
ds, err := NewDetectedSpam(ctx, db)
require.NoError(t, err)

spamEntry := DetectedSpamInfo{
Expand All @@ -124,7 +128,7 @@ func TestDetectedSpam_Read(t *testing.T) {
_, err = db.Exec("INSERT INTO detected_spam (text, user_id, user_name, timestamp, checks) VALUES (?, ?, ?, ?, ?)", spamEntry.Text, spamEntry.UserID, spamEntry.UserName, spamEntry.Timestamp, checksJSON)
require.NoError(t, err)

entries, err := ds.Read()
entries, err := ds.Read(ctx)
require.NoError(t, err)
require.Len(t, entries, 1)

Expand All @@ -143,8 +147,9 @@ func TestDetectedSpam_Read_LimitExceeded(t *testing.T) {
db, err := sqlx.Open("sqlite", ":memory:")
require.NoError(t, err)
defer db.Close()
ctx := context.Background()

ds, err := NewDetectedSpam(db)
ds, err := NewDetectedSpam(ctx, db)
require.NoError(t, err)

for i := 0; i < maxDetectedSpamEntries+10; i++ {
Expand All @@ -163,11 +168,11 @@ func TestDetectedSpam_Read_LimitExceeded(t *testing.T) {
},
}

err = ds.Write(spamEntry, checks)
err = ds.Write(ctx, spamEntry, checks)
require.NoError(t, err)
}

entries, err := ds.Read()
entries, err := ds.Read(ctx)
require.NoError(t, err)
assert.Len(t, entries, maxDetectedSpamEntries, "expected to retrieve only the maximum number of entries")
}
Expand All @@ -176,11 +181,12 @@ func TestDetectedSpam_Read_EmptyDB(t *testing.T) {
db, err := sqlx.Open("sqlite", ":memory:")
require.NoError(t, err)
defer db.Close()
ctx := context.Background()

ds, err := NewDetectedSpam(db)
ds, err := NewDetectedSpam(ctx, db)
require.NoError(t, err)

entries, err := ds.Read()
entries, err := ds.Read(ctx)
require.NoError(t, err)
assert.Empty(t, entries, "Expected no entries in an empty database")
}
40 changes: 27 additions & 13 deletions app/webapi/mocks/detected_spam.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Loading

0 comments on commit 1281b49

Please sign in to comment.