From 1281b49405b3be095d8bf074ffaf8dc31326ac97 Mon Sep 17 00:00:00 2001 From: Umputun Date: Sun, 29 Dec 2024 02:52:24 -0600 Subject: [PATCH] add ctx support for detected spam store --- app/main.go | 10 ++++---- app/main_test.go | 2 +- app/storage/detected_spam.go | 19 ++++++++------- app/storage/detected_spam_test.go | 32 +++++++++++++++---------- app/webapi/mocks/detected_spam.go | 40 +++++++++++++++++++++---------- app/webapi/webapi.go | 10 ++++---- app/webapi/webapi_test.go | 4 ++-- 7 files changed, 69 insertions(+), 48 deletions(-) diff --git a/app/main.go b/app/main.go index 8c402b60..ef2321db 100644 --- a/app/main.go +++ b/app/main.go @@ -265,7 +265,7 @@ func execute(ctx context.Context, opts options) error { defer loggerWr.Close() // make spam logger - spamLogger, err := makeSpamLogger(loggerWr, dataDB) + spamLogger, err := makeSpamLogger(ctx, loggerWr, dataDB) if err != nil { return fmt.Errorf("can't make spam logger, %w", err) } @@ -358,7 +358,7 @@ func activateServer(ctx context.Context, opts options, sf *bot.SpamFilter, loc * } // make store and load approved users - detectedSpamStore, auErr := storage.NewDetectedSpam(dataDB) + detectedSpamStore, auErr := storage.NewDetectedSpam(ctx, dataDB) if auErr != nil { return fmt.Errorf("can't make approved users store, %w", auErr) } @@ -546,9 +546,9 @@ func (n nopWriteCloser) Close() error { return nil } // makeSpamLogger creates spam logger to keep reports about spam messages // it writes json lines to the provided writer -func makeSpamLogger(wr io.Writer, dataDB *sqlx.DB) (events.SpamLogger, error) { +func makeSpamLogger(ctx context.Context, wr io.Writer, dataDB *sqlx.DB) (events.SpamLogger, error) { // make store and load approved users - detectedSpamStore, auErr := storage.NewDetectedSpam(dataDB) + detectedSpamStore, auErr := storage.NewDetectedSpam(ctx, dataDB) if auErr != nil { return nil, fmt.Errorf("can't make approved users store, %w", auErr) } @@ -587,7 +587,7 @@ func makeSpamLogger(wr io.Writer, dataDB *sqlx.DB) (events.SpamLogger, error) { UserName: msg.From.Username, Timestamp: time.Now().In(time.Local), } - if err := detectedSpamStore.Write(rec, response.CheckResults); err != nil { + if err := detectedSpamStore.Write(ctx, rec, response.CheckResults); err != nil { log.Printf("[WARN] can't write to db, %v", err) } }) diff --git a/app/main_test.go b/app/main_test.go index 3c269bfe..b5a9d8c3 100644 --- a/app/main_test.go +++ b/app/main_test.go @@ -30,7 +30,7 @@ func TestMakeSpamLogger(t *testing.T) { require.NoError(t, err) defer db.Close() - logger, err := makeSpamLogger(file, db) + logger, err := makeSpamLogger(context.Background(), file, db) require.NoError(t, err) msg := &bot.Message{ diff --git a/app/storage/detected_spam.go b/app/storage/detected_spam.go index e9d69a53..84d4570e 100644 --- a/app/storage/detected_spam.go +++ b/app/storage/detected_spam.go @@ -1,6 +1,7 @@ package storage import ( + "context" "encoding/json" "fmt" "log" @@ -34,7 +35,7 @@ type DetectedSpamInfo struct { } // NewDetectedSpam creates a new DetectedSpam storage -func NewDetectedSpam(db *sqlx.DB) (*DetectedSpam, error) { +func NewDetectedSpam(ctx context.Context, db *sqlx.DB) (*DetectedSpam, error) { if db == nil { return nil, fmt.Errorf("db connection is nil") } @@ -55,14 +56,14 @@ func NewDetectedSpam(db *sqlx.DB) (*DetectedSpam, error) { return nil, fmt.Errorf("failed to create detected_spam table: %w", err) } - _, err = db.Exec(`ALTER TABLE detected_spam ADD COLUMN added BOOLEAN DEFAULT 0`) + _, err = db.ExecContext(ctx, `ALTER TABLE detected_spam ADD COLUMN added BOOLEAN DEFAULT 0`) if err != nil { if !strings.Contains(err.Error(), "duplicate column name") { return nil, fmt.Errorf("failed to alter detected_spam table: %w", err) } } // add index on timestamp - if _, err = db.Exec(`CREATE INDEX IF NOT EXISTS idx_detected_spam_timestamp ON detected_spam(timestamp)`); err != nil { + if _, err = db.ExecContext(ctx, `CREATE INDEX IF NOT EXISTS idx_detected_spam_timestamp ON detected_spam(timestamp)`); err != nil { return nil, fmt.Errorf("failed to create index on timestamp: %w", err) } @@ -70,7 +71,7 @@ func NewDetectedSpam(db *sqlx.DB) (*DetectedSpam, error) { } // Write adds a new detected spam entry -func (ds *DetectedSpam) Write(entry DetectedSpamInfo, checks []spamcheck.Response) error { +func (ds *DetectedSpam) Write(ctx context.Context, entry DetectedSpamInfo, checks []spamcheck.Response) error { ds.lock.Lock() defer ds.lock.Unlock() @@ -80,7 +81,7 @@ func (ds *DetectedSpam) Write(entry DetectedSpamInfo, checks []spamcheck.Respons } query := `INSERT INTO detected_spam (text, user_id, user_name, timestamp, checks) VALUES (?, ?, ?, ?, ?)` - if _, err := ds.db.Exec(query, entry.Text, entry.UserID, entry.UserName, entry.Timestamp, checksJSON); err != nil { + if _, err := ds.db.ExecContext(ctx, query, entry.Text, entry.UserID, entry.UserName, entry.Timestamp, checksJSON); err != nil { return fmt.Errorf("failed to insert detected spam entry: %w", err) } @@ -89,24 +90,24 @@ func (ds *DetectedSpam) Write(entry DetectedSpamInfo, checks []spamcheck.Respons } // SetAddedToSamplesFlag sets the added flag to true for the detected spam entry with the given id -func (ds *DetectedSpam) SetAddedToSamplesFlag(id int64) error { +func (ds *DetectedSpam) SetAddedToSamplesFlag(ctx context.Context, id int64) error { ds.lock.Lock() defer ds.lock.Unlock() query := `UPDATE detected_spam SET added = 1 WHERE id = ?` - if _, err := ds.db.Exec(query, id); err != nil { + if _, err := ds.db.ExecContext(ctx, query, id); err != nil { return fmt.Errorf("failed to update added to samples flag: %w", err) } return nil } // Read returns all detected spam entries -func (ds *DetectedSpam) Read() ([]DetectedSpamInfo, error) { +func (ds *DetectedSpam) Read(ctx context.Context) ([]DetectedSpamInfo, error) { ds.lock.RLock() defer ds.lock.RUnlock() var entries []DetectedSpamInfo - err := ds.db.Select(&entries, "SELECT * FROM detected_spam ORDER BY timestamp DESC LIMIT ?", maxDetectedSpamEntries) + err := ds.db.SelectContext(ctx, &entries, "SELECT * FROM detected_spam ORDER BY timestamp DESC LIMIT ?", maxDetectedSpamEntries) if err != nil { return nil, fmt.Errorf("failed to get detected spam entries: %w", err) } diff --git a/app/storage/detected_spam_test.go b/app/storage/detected_spam_test.go index 0cdd6c5e..aaf89155 100644 --- a/app/storage/detected_spam_test.go +++ b/app/storage/detected_spam_test.go @@ -1,6 +1,7 @@ package storage import ( + "context" "encoding/json" "testing" "time" @@ -17,7 +18,7 @@ func TestDetectedSpam_NewDetectedSpam(t *testing.T) { require.NoError(t, err) defer db.Close() - _, err = NewDetectedSpam(db) + _, err = NewDetectedSpam(context.Background(), db) require.NoError(t, err) var exists int @@ -30,8 +31,9 @@ func TestDetectedSpam_Write(t *testing.T) { db, err := sqlx.Open("sqlite", ":memory:") require.NoError(t, err) defer db.Close() + ctx := context.Background() - ds, err := NewDetectedSpam(db) + ds, err := NewDetectedSpam(ctx, db) require.NoError(t, err) spamEntry := DetectedSpamInfo{ @@ -49,7 +51,7 @@ func TestDetectedSpam_Write(t *testing.T) { }, } - err = ds.Write(spamEntry, checks) + err = ds.Write(ctx, spamEntry, checks) require.NoError(t, err) var count int @@ -62,8 +64,9 @@ func TestSetAddedToSamplesFlag(t *testing.T) { db, err := sqlx.Open("sqlite", ":memory:") require.NoError(t, err) defer db.Close() + ctx := context.Background() - ds, err := NewDetectedSpam(db) + ds, err := NewDetectedSpam(ctx, db) require.NoError(t, err) spamEntry := DetectedSpamInfo{ @@ -81,14 +84,14 @@ func TestSetAddedToSamplesFlag(t *testing.T) { }, } - err = ds.Write(spamEntry, checks) + err = ds.Write(ctx, spamEntry, checks) require.NoError(t, err) var added bool err = db.Get(&added, "SELECT added FROM detected_spam WHERE text = ?", spamEntry.Text) require.NoError(t, err) assert.False(t, added) - err = ds.SetAddedToSamplesFlag(1) + err = ds.SetAddedToSamplesFlag(ctx, 1) require.NoError(t, err) err = db.Get(&added, "SELECT added FROM detected_spam WHERE text = ?", spamEntry.Text) @@ -100,8 +103,9 @@ func TestDetectedSpam_Read(t *testing.T) { db, err := sqlx.Open("sqlite", ":memory:") require.NoError(t, err) defer db.Close() + ctx := context.Background() - ds, err := NewDetectedSpam(db) + ds, err := NewDetectedSpam(ctx, db) require.NoError(t, err) spamEntry := DetectedSpamInfo{ @@ -124,7 +128,7 @@ func TestDetectedSpam_Read(t *testing.T) { _, err = db.Exec("INSERT INTO detected_spam (text, user_id, user_name, timestamp, checks) VALUES (?, ?, ?, ?, ?)", spamEntry.Text, spamEntry.UserID, spamEntry.UserName, spamEntry.Timestamp, checksJSON) require.NoError(t, err) - entries, err := ds.Read() + entries, err := ds.Read(ctx) require.NoError(t, err) require.Len(t, entries, 1) @@ -143,8 +147,9 @@ func TestDetectedSpam_Read_LimitExceeded(t *testing.T) { db, err := sqlx.Open("sqlite", ":memory:") require.NoError(t, err) defer db.Close() + ctx := context.Background() - ds, err := NewDetectedSpam(db) + ds, err := NewDetectedSpam(ctx, db) require.NoError(t, err) for i := 0; i < maxDetectedSpamEntries+10; i++ { @@ -163,11 +168,11 @@ func TestDetectedSpam_Read_LimitExceeded(t *testing.T) { }, } - err = ds.Write(spamEntry, checks) + err = ds.Write(ctx, spamEntry, checks) require.NoError(t, err) } - entries, err := ds.Read() + entries, err := ds.Read(ctx) require.NoError(t, err) assert.Len(t, entries, maxDetectedSpamEntries, "expected to retrieve only the maximum number of entries") } @@ -176,11 +181,12 @@ func TestDetectedSpam_Read_EmptyDB(t *testing.T) { db, err := sqlx.Open("sqlite", ":memory:") require.NoError(t, err) defer db.Close() + ctx := context.Background() - ds, err := NewDetectedSpam(db) + ds, err := NewDetectedSpam(ctx, db) require.NoError(t, err) - entries, err := ds.Read() + entries, err := ds.Read(ctx) require.NoError(t, err) assert.Empty(t, entries, "Expected no entries in an empty database") } diff --git a/app/webapi/mocks/detected_spam.go b/app/webapi/mocks/detected_spam.go index c3dc39ce..52e57065 100644 --- a/app/webapi/mocks/detected_spam.go +++ b/app/webapi/mocks/detected_spam.go @@ -4,6 +4,7 @@ package mocks import ( + "context" "github.com/umputun/tg-spam/app/storage" "sync" ) @@ -14,10 +15,10 @@ import ( // // // make and configure a mocked webapi.DetectedSpam // mockedDetectedSpam := &DetectedSpamMock{ -// ReadFunc: func() ([]storage.DetectedSpamInfo, error) { +// ReadFunc: func(ctx context.Context) ([]storage.DetectedSpamInfo, error) { // panic("mock out the Read method") // }, -// SetAddedToSamplesFlagFunc: func(id int64) error { +// SetAddedToSamplesFlagFunc: func(ctx context.Context, id int64) error { // panic("mock out the SetAddedToSamplesFlag method") // }, // } @@ -28,18 +29,22 @@ import ( // } type DetectedSpamMock struct { // ReadFunc mocks the Read method. - ReadFunc func() ([]storage.DetectedSpamInfo, error) + ReadFunc func(ctx context.Context) ([]storage.DetectedSpamInfo, error) // SetAddedToSamplesFlagFunc mocks the SetAddedToSamplesFlag method. - SetAddedToSamplesFlagFunc func(id int64) error + SetAddedToSamplesFlagFunc func(ctx context.Context, id int64) error // calls tracks calls to the methods. calls struct { // Read holds details about calls to the Read method. Read []struct { + // Ctx is the ctx argument value. + Ctx context.Context } // SetAddedToSamplesFlag holds details about calls to the SetAddedToSamplesFlag method. SetAddedToSamplesFlag []struct { + // Ctx is the ctx argument value. + Ctx context.Context // ID is the id argument value. ID int64 } @@ -49,16 +54,19 @@ type DetectedSpamMock struct { } // Read calls ReadFunc. -func (mock *DetectedSpamMock) Read() ([]storage.DetectedSpamInfo, error) { +func (mock *DetectedSpamMock) Read(ctx context.Context) ([]storage.DetectedSpamInfo, error) { if mock.ReadFunc == nil { panic("DetectedSpamMock.ReadFunc: method is nil but DetectedSpam.Read was just called") } callInfo := struct { - }{} + Ctx context.Context + }{ + Ctx: ctx, + } mock.lockRead.Lock() mock.calls.Read = append(mock.calls.Read, callInfo) mock.lockRead.Unlock() - return mock.ReadFunc() + return mock.ReadFunc(ctx) } // ReadCalls gets all the calls that were made to Read. @@ -66,8 +74,10 @@ func (mock *DetectedSpamMock) Read() ([]storage.DetectedSpamInfo, error) { // // len(mockedDetectedSpam.ReadCalls()) func (mock *DetectedSpamMock) ReadCalls() []struct { + Ctx context.Context } { var calls []struct { + Ctx context.Context } mock.lockRead.RLock() calls = mock.calls.Read @@ -83,19 +93,21 @@ func (mock *DetectedSpamMock) ResetReadCalls() { } // SetAddedToSamplesFlag calls SetAddedToSamplesFlagFunc. -func (mock *DetectedSpamMock) SetAddedToSamplesFlag(id int64) error { +func (mock *DetectedSpamMock) SetAddedToSamplesFlag(ctx context.Context, id int64) error { if mock.SetAddedToSamplesFlagFunc == nil { panic("DetectedSpamMock.SetAddedToSamplesFlagFunc: method is nil but DetectedSpam.SetAddedToSamplesFlag was just called") } callInfo := struct { - ID int64 + Ctx context.Context + ID int64 }{ - ID: id, + Ctx: ctx, + ID: id, } mock.lockSetAddedToSamplesFlag.Lock() mock.calls.SetAddedToSamplesFlag = append(mock.calls.SetAddedToSamplesFlag, callInfo) mock.lockSetAddedToSamplesFlag.Unlock() - return mock.SetAddedToSamplesFlagFunc(id) + return mock.SetAddedToSamplesFlagFunc(ctx, id) } // SetAddedToSamplesFlagCalls gets all the calls that were made to SetAddedToSamplesFlag. @@ -103,10 +115,12 @@ func (mock *DetectedSpamMock) SetAddedToSamplesFlag(id int64) error { // // len(mockedDetectedSpam.SetAddedToSamplesFlagCalls()) func (mock *DetectedSpamMock) SetAddedToSamplesFlagCalls() []struct { - ID int64 + Ctx context.Context + ID int64 } { var calls []struct { - ID int64 + Ctx context.Context + ID int64 } mock.lockSetAddedToSamplesFlag.RLock() calls = mock.calls.SetAddedToSamplesFlag diff --git a/app/webapi/webapi.go b/app/webapi/webapi.go index 05d6229a..79869b0b 100644 --- a/app/webapi/webapi.go +++ b/app/webapi/webapi.go @@ -113,8 +113,8 @@ type Locator interface { // DetectedSpam is a storage interface used to get detected spam messages and set added flag. type DetectedSpam interface { - Read() ([]storage.DetectedSpamInfo, error) - SetAddedToSamplesFlag(id int64) error + Read(ctx context.Context) ([]storage.DetectedSpamInfo, error) + SetAddedToSamplesFlag(ctx context.Context, id int64) error } // NewServer creates a new web API server. @@ -491,8 +491,8 @@ func (s *Server) htmlManageUsersHandler(w http.ResponseWriter, _ *http.Request) } } -func (s *Server) htmlDetectedSpamHandler(w http.ResponseWriter, _ *http.Request) { - ds, err := s.DetectedSpam.Read() +func (s *Server) htmlDetectedSpamHandler(w http.ResponseWriter, r *http.Request) { + ds, err := s.DetectedSpam.Read(r.Context()) if err != nil { log.Printf("[ERROR] Failed to fetch detected spam: %v", err) http.Error(w, "Internal Server Error", http.StatusInternalServerError) @@ -545,7 +545,7 @@ func (s *Server) htmlAddDetectedSpamHandler(w http.ResponseWriter, r *http.Reque return } - if err := s.DetectedSpam.SetAddedToSamplesFlag(id); err != nil { + if err := s.DetectedSpam.SetAddedToSamplesFlag(r.Context(), id); err != nil { log.Printf("[WARN] failed to update detected spam: %v", err) reportErr(fmt.Errorf("can't update detected spam: %v", err), http.StatusInternalServerError) return diff --git a/app/webapi/webapi_test.go b/app/webapi/webapi_test.go index c575ecd5..cb65c11a 100644 --- a/app/webapi/webapi_test.go +++ b/app/webapi/webapi_test.go @@ -763,7 +763,7 @@ func TestServer_updateApprovedUsersHandler(t *testing.T) { func TestServer_htmlDetectedSpamHandler(t *testing.T) { calls := 0 ds := &mocks.DetectedSpamMock{ - ReadFunc: func() ([]storage.DetectedSpamInfo, error) { + ReadFunc: func(ctx context.Context) ([]storage.DetectedSpamInfo, error) { calls++ if calls > 1 { return nil, errors.New("test error") @@ -816,7 +816,7 @@ func TestServer_htmlDetectedSpamHandler(t *testing.T) { func TestServer_htmlAddDetectedSpamHandler(t *testing.T) { ds := &mocks.DetectedSpamMock{ - SetAddedToSamplesFlagFunc: func(id int64) error { + SetAddedToSamplesFlagFunc: func(ctx context.Context, id int64) error { return nil }, }