From e1c6faefa82ef3ee8109719148fdaa4de1065d24 Mon Sep 17 00:00:00 2001 From: Umputun Date: Sat, 28 Dec 2024 22:45:42 -0600 Subject: [PATCH] Refactor approved users storage methods to use context for timeout management --- .gitignore | 2 + app/main.go | 6 +- app/storage/approved_users.go | 54 ++++++-------- app/storage/approved_users_test.go | 116 ++++++++++++++++++++++------- lib/tgspam/detector.go | 54 +++++++++----- lib/tgspam/detector_test.go | 8 +- lib/tgspam/mocks/user_storage.go | 62 +++++++++------ 7 files changed, 205 insertions(+), 97 deletions(-) diff --git a/.gitignore b/.gitignore index 180af136..fbebdb07 100644 --- a/.gitignore +++ b/.gitignore @@ -27,6 +27,8 @@ data/spam-dynamic.txt data/ham-dynamic.txt data/approved-users.txt tg-spam.db +tg-spam.db-shm +tg-spam.db-wal logs/ site/public/ _examples/simplechat/messages.db \ No newline at end of file diff --git a/app/main.go b/app/main.go index 8da5fd5d..8c402b60 100644 --- a/app/main.go +++ b/app/main.go @@ -49,6 +49,7 @@ type options struct { HistoryDuration time.Duration `long:"history-duration" env:"HISTORY_DURATION" default:"24h" description:"history duration"` HistoryMinSize int `long:"history-min-size" env:"HISTORY_MIN_SIZE" default:"1000" description:"history minimal size to keep"` + StorageTimeout time.Duration `long:"storage-timeout" env:"STORAGE_TIMEOUT" default:"0s" description:"storage timeout"` Logger struct { Enabled bool `long:"enabled" env:"ENABLED" description:"enable spam rotated logs"` @@ -212,7 +213,7 @@ func execute(ctx context.Context, opts options) error { log.Printf("[DEBUG] data db: %s", dataFile) // make store and load approved users - approvedUsersStore, auErr := storage.NewApprovedUsers(dataDB) + approvedUsersStore, auErr := storage.NewApprovedUsers(ctx, dataDB) if auErr != nil { return fmt.Errorf("can't make approved users store, %w", auErr) } @@ -437,6 +438,9 @@ func makeDetector(opts options) *tgspam.Detector { detectorConfig.FirstMessageOnly = false detectorConfig.FirstMessagesCount = 0 } + if opts.StorageTimeout > 0 { // if StorageTimeout is non-zero, set it. If zero, storage timeout is disabled + detectorConfig.StorageTimeout = opts.StorageTimeout + } detector := tgspam.NewDetector(detectorConfig) log.Printf("[DEBUG] detector config: %+v", detectorConfig) diff --git a/app/storage/approved_users.go b/app/storage/approved_users.go index 34b5a031..f7c9e1a6 100644 --- a/app/storage/approved_users.go +++ b/app/storage/approved_users.go @@ -1,6 +1,7 @@ package storage import ( + "context" "fmt" "log" "sync" @@ -12,13 +13,12 @@ import ( "github.com/umputun/tg-spam/lib/approved" ) -// ApprovedUsers is a storage for approved users ids +// ApprovedUsers is a storage for approved users type ApprovedUsers struct { db *sqlx.DB lock *sync.RWMutex } -// ApprovedUsersInfo represents information about an approved user. type approvedUsersInfo struct { UserID string `db:"id"` UserName string `db:"name"` @@ -26,7 +26,7 @@ type approvedUsersInfo struct { } // NewApprovedUsers creates a new ApprovedUsers storage -func NewApprovedUsers(db *sqlx.DB) (*ApprovedUsers, error) { +func NewApprovedUsers(ctx context.Context, db *sqlx.DB) (*ApprovedUsers, error) { if db == nil { return nil, fmt.Errorf("db connection is nil") } @@ -34,34 +34,33 @@ func NewApprovedUsers(db *sqlx.DB) (*ApprovedUsers, error) { return nil, fmt.Errorf("failed to set sqlite pragma: %w", err) } - // migrate the table if necessary - err := migrateTable(db) - if err != nil { + if err := migrateTable(db); err != nil { return nil, err } - // create the table if it doesn't exist - _, err = db.Exec(`CREATE TABLE IF NOT EXISTS approved_users ( + query := `CREATE TABLE IF NOT EXISTS approved_users ( id TEXT PRIMARY KEY, name TEXT, timestamp DATETIME DEFAULT CURRENT_TIMESTAMP - )`) - if err != nil { + )` + if _, err := db.ExecContext(ctx, query); err != nil { return nil, fmt.Errorf("failed to create approved_users table: %w", err) } return &ApprovedUsers{db: db, lock: &sync.RWMutex{}}, nil } -// Read returns all approved users. -func (au *ApprovedUsers) Read() ([]approved.UserInfo, error) { +// Read returns a list of all approved users +func (au *ApprovedUsers) Read(ctx context.Context) ([]approved.UserInfo, error) { au.lock.RLock() defer au.lock.RUnlock() + users := []approvedUsersInfo{} - err := au.db.Select(&users, "SELECT id, name, timestamp FROM approved_users ORDER BY timestamp DESC") + err := au.db.SelectContext(ctx, &users, "SELECT id, name, timestamp FROM approved_users ORDER BY timestamp DESC") if err != nil { return nil, fmt.Errorf("failed to get approved users: %w", err) } + res := make([]approved.UserInfo, len(users)) for i, u := range users { res[i] = approved.UserInfo{ @@ -74,35 +73,35 @@ func (au *ApprovedUsers) Read() ([]approved.UserInfo, error) { return res, nil } -// Write writes new user info to the storage -func (au *ApprovedUsers) Write(user approved.UserInfo) error { +// Write adds a new approved user +func (au *ApprovedUsers) Write(ctx context.Context, user approved.UserInfo) error { au.lock.Lock() defer au.lock.Unlock() + if user.Timestamp.IsZero() { user.Timestamp = time.Now() } - // Prepare the query to insert a new record, ignoring if it already exists + query := "INSERT OR IGNORE INTO approved_users (id, name, timestamp) VALUES (?, ?, ?)" - if _, err := au.db.Exec(query, user.UserID, user.UserName, user.Timestamp); err != nil { + if _, err := au.db.ExecContext(ctx, query, user.UserID, user.UserName, user.Timestamp); err != nil { return fmt.Errorf("failed to insert user %+v: %w", user, err) } log.Printf("[INFO] user %s added to approved users", user.String()) return nil } -// Delete deletes the given id from the storage -func (au *ApprovedUsers) Delete(id string) error { +// Delete removes an approved user by its ID +func (au *ApprovedUsers) Delete(ctx context.Context, id string) error { au.lock.Lock() defer au.lock.Unlock() - var user approvedUsersInfo - // retrieve the user's name and timestamp for logging purposes - err := au.db.Get(&user, "SELECT id, name, timestamp FROM approved_users WHERE id = ?", id) + var user approvedUsersInfo + err := au.db.GetContext(ctx, &user, "SELECT id, name, timestamp FROM approved_users WHERE id = ?", id) if err != nil { return fmt.Errorf("failed to get approved user for id %s: %w", id, err) } - if _, err := au.db.Exec("DELETE FROM approved_users WHERE id = ?", id); err != nil { + if _, err := au.db.ExecContext(ctx, "DELETE FROM approved_users WHERE id = ?", id); err != nil { return fmt.Errorf("failed to delete id %s: %w", id, err) } @@ -111,7 +110,6 @@ func (au *ApprovedUsers) Delete(id string) error { } func migrateTable(db *sqlx.DB) error { - // Check if the table exists var exists int err := db.Get(&exists, "SELECT COUNT(*) FROM sqlite_master WHERE type='table' AND name='approved_users'") if err != nil { @@ -119,10 +117,9 @@ func migrateTable(db *sqlx.DB) error { } if exists == 0 { - return nil // Table does not exist, no need to migrate + return nil } - // Get column info for 'id' column var cols []struct { CID int `db:"cid"` Name string `db:"name"` @@ -136,12 +133,9 @@ func migrateTable(db *sqlx.DB) error { return fmt.Errorf("failed to get table info for approved_users: %w", err) } - // Check if 'id' column is INTEGER for _, col := range cols { if col.Name == "id" && col.Type == "INTEGER" { - // Drop the table if 'id' is of type INTEGER - _, err = db.Exec("DROP TABLE approved_users") - if err != nil { + if _, err = db.Exec("DROP TABLE approved_users"); err != nil { return fmt.Errorf("failed to drop old approved_users table: %w", err) } break diff --git a/app/storage/approved_users_test.go b/app/storage/approved_users_test.go index 764b88de..702a5360 100644 --- a/app/storage/approved_users_test.go +++ b/app/storage/approved_users_test.go @@ -1,6 +1,7 @@ package storage import ( + "context" "strings" "testing" "time" @@ -18,7 +19,7 @@ func TestApprovedUsers_NewApprovedUsers(t *testing.T) { require.NoError(t, err) defer db.Close() - _, err = NewApprovedUsers(db) + _, err = NewApprovedUsers(context.Background(), db) require.NoError(t, err) // check if the table and columns exist @@ -41,7 +42,7 @@ func TestApprovedUsers_NewApprovedUsers(t *testing.T) { _, err = db.Exec(`CREATE TABLE approved_users (id TEXT PRIMARY KEY, name TEXT, timestamp DATETIME DEFAULT CURRENT_TIMESTAMP)`) require.NoError(t, err) - _, err = NewApprovedUsers(db) + _, err = NewApprovedUsers(context.Background(), db) require.NoError(t, err) // Verify that the existing structure has not changed @@ -60,7 +61,7 @@ func TestApprovedUsers_NewApprovedUsers(t *testing.T) { _, err = db.Exec(`CREATE TABLE approved_users (id INTEGER PRIMARY KEY, name TEXT, timestamp DATETIME DEFAULT CURRENT_TIMESTAMP)`) require.NoError(t, err) - _, err = NewApprovedUsers(db) + _, err = NewApprovedUsers(context.Background(), db) require.NoError(t, err) // Verify that the new table is created with TEXT id @@ -75,7 +76,8 @@ func TestApprovedUsers_Write(t *testing.T) { db, e := sqlx.Open("sqlite", ":memory:") require.NoError(t, e) defer db.Close() - au, e := NewApprovedUsers(db) + ctx := context.Background() + au, e := NewApprovedUsers(ctx, db) require.NoError(t, e) t.Run("write new user without timestamp", func(t *testing.T) { @@ -83,7 +85,7 @@ func TestApprovedUsers_Write(t *testing.T) { require.NoError(t, err) require.NoError(t, err) - err = au.Write(approved.UserInfo{UserID: "123", UserName: "John Doe"}) + err = au.Write(ctx, approved.UserInfo{UserID: "123", UserName: "John Doe"}) require.NoError(t, err) var user approvedUsersInfo @@ -96,9 +98,8 @@ func TestApprovedUsers_Write(t *testing.T) { t.Run("write new user with timestamp", func(t *testing.T) { _, err := db.Exec("DELETE FROM approved_users") require.NoError(t, err) - - require.NoError(t, err) - err = au.Write(approved.UserInfo{UserID: "123", UserName: "John Doe", Timestamp: time.Date(2023, 10, 2, 0, 0, 0, 0, time.UTC)}) + ctx := context.Background() + err = au.Write(ctx, approved.UserInfo{UserID: "123", UserName: "John Doe", Timestamp: time.Date(2023, 10, 2, 0, 0, 0, 0, time.UTC)}) require.NoError(t, err) var user approvedUsersInfo @@ -111,12 +112,12 @@ func TestApprovedUsers_Write(t *testing.T) { t.Run("update existing user", func(t *testing.T) { _, err := db.Exec("DELETE FROM approved_users") require.NoError(t, err) - + ctx := context.Background() require.NoError(t, err) - err = au.Write(approved.UserInfo{UserID: "123", UserName: "John Doe", Timestamp: time.Date(2023, 10, 2, 0, 0, 0, 0, time.UTC)}) + err = au.Write(ctx, approved.UserInfo{UserID: "123", UserName: "John Doe", Timestamp: time.Date(2023, 10, 2, 0, 0, 0, 0, time.UTC)}) require.NoError(t, err) - err = au.Write(approved.UserInfo{UserID: "123", UserName: "John Doe Updated"}) + err = au.Write(ctx, approved.UserInfo{UserID: "123", UserName: "John Doe Updated"}) require.NoError(t, err) var user approvedUsersInfo @@ -154,20 +155,20 @@ func TestApprovedUsers_StoreAndRead(t *testing.T) { expected: []string{"123", "456"}, }, } - + ctx := context.Background() for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { db, err := NewSqliteDB(":memory:") require.NoError(t, err) - au, err := NewApprovedUsers(db) + au, err := NewApprovedUsers(ctx, db) require.NoError(t, err) for _, id := range tt.ids { - err = au.Write(approved.UserInfo{UserID: id, UserName: "name_" + id}) + err = au.Write(ctx, approved.UserInfo{UserID: id, UserName: "name_" + id}) require.NoError(t, err) } - res, err := au.Read() + res, err := au.Read(ctx) require.NoError(t, err) assert.Equal(t, len(tt.expected), len(res)) }) @@ -177,13 +178,14 @@ func TestApprovedUsers_StoreAndRead(t *testing.T) { func TestApprovedUsers_Read(t *testing.T) { db, e := NewSqliteDB(":memory:") require.NoError(t, e) - au, e := NewApprovedUsers(db) + ctx := context.Background() + au, e := NewApprovedUsers(ctx, db) require.NoError(t, e) t.Run("empty", func(t *testing.T) { _, err := db.Exec("DELETE FROM approved_users") require.NoError(t, err) - users, err := au.Read() + users, err := au.Read(ctx) require.NoError(t, err) assert.Equal(t, []approved.UserInfo{}, users) }) @@ -200,7 +202,7 @@ func TestApprovedUsers_Read(t *testing.T) { "456", "Jane Doe", time.Date(2023, 10, 3, 0, 0, 0, 0, time.UTC)) require.NoError(t, err) - users, err := au.Read() + users, err := au.Read(ctx) require.NoError(t, err) assert.Equal(t, []approved.UserInfo{ {UserID: "456", UserName: "Jane Doe", Timestamp: time.Date(2023, 10, 3, 0, 0, 0, 0, time.UTC)}, @@ -213,13 +215,13 @@ func TestApprovedUsers_Read(t *testing.T) { require.NoError(t, err) require.NoError(t, err) - err = au.Write(approved.UserInfo{UserID: "123", UserName: "John Doe", Timestamp: time.Date(2023, 10, 2, 0, 0, 0, 0, time.UTC)}) + err = au.Write(ctx, approved.UserInfo{UserID: "123", UserName: "John Doe", Timestamp: time.Date(2023, 10, 2, 0, 0, 0, 0, time.UTC)}) require.NoError(t, err) - err = au.Write(approved.UserInfo{UserID: "456", UserName: "Jane Doe", Timestamp: time.Date(2023, 10, 3, 0, 0, 0, 0, time.UTC)}) + err = au.Write(ctx, approved.UserInfo{UserID: "456", UserName: "Jane Doe", Timestamp: time.Date(2023, 10, 3, 0, 0, 0, 0, time.UTC)}) require.NoError(t, err) - users, err := au.Read() + users, err := au.Read(ctx) require.NoError(t, err) assert.Equal(t, []approved.UserInfo{ {UserID: "456", UserName: "Jane Doe", Timestamp: time.Date(2023, 10, 3, 0, 0, 0, 0, time.UTC)}, @@ -231,7 +233,8 @@ func TestApprovedUsers_Read(t *testing.T) { func TestApprovedUsers_Delete(t *testing.T) { db, e := NewSqliteDB(":memory:") require.NoError(t, e) - au, e := NewApprovedUsers(db) + ctx := context.Background() + au, e := NewApprovedUsers(ctx, db) require.NoError(t, e) t.Run("delete existing user", func(t *testing.T) { @@ -239,10 +242,10 @@ func TestApprovedUsers_Delete(t *testing.T) { require.NoError(t, err) require.NoError(t, err) - err = au.Write(approved.UserInfo{UserID: "123", UserName: "John Doe"}) + err = au.Write(ctx, approved.UserInfo{UserID: "123", UserName: "John Doe"}) require.NoError(t, err) - err = au.Delete("123") + err = au.Delete(ctx, "123") require.NoError(t, err) var user approvedUsersInfo @@ -255,7 +258,70 @@ func TestApprovedUsers_Delete(t *testing.T) { _, err := db.Exec("DELETE FROM approved_users") require.NoError(t, err) - err = au.Delete("123") + err = au.Delete(ctx, "123") + require.Error(t, err) + }) +} + +func TestApprovedUsers_ContextCancellation(t *testing.T) { + db, err := NewSqliteDB(":memory:") + require.NoError(t, err) + defer db.Close() + + ctx := context.Background() + au, err := NewApprovedUsers(ctx, db) + require.NoError(t, err) + + t.Run("new with cancelled context", func(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + cancel() + _, err := NewApprovedUsers(ctx, db) + require.Error(t, err) + assert.Contains(t, err.Error(), "context canceled") + }) + + t.Run("read with cancelled context", func(t *testing.T) { + // prepare data + err := au.Write(ctx, approved.UserInfo{UserID: "123", UserName: "test"}) + require.NoError(t, err) + + ctxCanceled, cancel := context.WithCancel(context.Background()) + cancel() + + _, err = au.Read(ctxCanceled) + require.Error(t, err) + assert.Contains(t, err.Error(), "context canceled") + }) + + t.Run("write with cancelled context", func(t *testing.T) { + ctxCanceled, cancel := context.WithCancel(context.Background()) + cancel() + + err := au.Write(ctxCanceled, approved.UserInfo{UserID: "456", UserName: "test"}) + require.Error(t, err) + assert.Contains(t, err.Error(), "context canceled") + }) + + t.Run("delete with cancelled context", func(t *testing.T) { + // prepare data + err := au.Write(ctx, approved.UserInfo{UserID: "789", UserName: "test"}) + require.NoError(t, err) + + ctxCanceled, cancel := context.WithCancel(context.Background()) + cancel() + + err = au.Delete(ctxCanceled, "789") + require.Error(t, err) + assert.Contains(t, err.Error(), "context canceled") + }) + + t.Run("context timeout", func(t *testing.T) { + ctxTimeout, cancel := context.WithTimeout(context.Background(), time.Nanosecond) + defer cancel() + time.Sleep(time.Millisecond) + + err := au.Write(ctxTimeout, approved.UserInfo{UserID: "timeout", UserName: "test"}) require.Error(t, err) + assert.Contains(t, err.Error(), "context deadline exceeded") }) } diff --git a/lib/tgspam/detector.go b/lib/tgspam/detector.go index 02ca4257..08a0550a 100644 --- a/lib/tgspam/detector.go +++ b/lib/tgspam/detector.go @@ -3,6 +3,7 @@ package tgspam import ( "bufio" "bytes" + "context" "encoding/json" "fmt" "io" @@ -46,16 +47,17 @@ type Detector struct { // Config is a set of parameters for Detector. type Config struct { - SimilarityThreshold float64 // threshold for spam similarity, 0.0 - 1.0 - MinMsgLen int // minimum message length to check - MaxAllowedEmoji int // maximum number of emojis allowed in a message - CasAPI string // CAS API URL - FirstMessageOnly bool // if true, only the first message from a user is checked - FirstMessagesCount int // number of first messages to check for spam - HTTPClient HTTPClient // http client to use for requests - MinSpamProbability float64 // minimum spam probability to consider a message spam with classifier, if 0 - ignored - OpenAIVeto bool // if true, openai will be used to veto spam messages, otherwise it will be used to veto ham messages - MultiLangWords int // if true, check for number of multi-lingual words + SimilarityThreshold float64 // threshold for spam similarity, 0.0 - 1.0 + MinMsgLen int // minimum message length to check + MaxAllowedEmoji int // maximum number of emojis allowed in a message + CasAPI string // CAS API URL + FirstMessageOnly bool // if true, only the first message from a user is checked + FirstMessagesCount int // number of first messages to check for spam + HTTPClient HTTPClient // http client to use for requests + MinSpamProbability float64 // minimum spam probability to consider a message spam with classifier, if 0 - ignored + OpenAIVeto bool // if true, openai will be used to veto spam messages, otherwise it will be used to veto ham messages + MultiLangWords int // if true, check for number of multi-lingual words + StorageTimeout time.Duration // timeout for storage operations, if not set - no timeout } // SampleUpdater is an interface for updating spam/ham samples on the fly. @@ -66,9 +68,9 @@ type SampleUpdater interface { // UserStorage is an interface for approved users storage. type UserStorage interface { - Read() ([]approved.UserInfo, error) // read approved users from storage - Write(au approved.UserInfo) error // write approved user to storage - Delete(id string) error // delete approved user from storage + Read(ctx context.Context) ([]approved.UserInfo, error) // read approved users from storage + Write(ctx context.Context, au approved.UserInfo) error // write approved user to storage + Delete(ctx context.Context, id string) error // delete approved user from storage } // HTTPClient is an interface for http client, satisfied by http.Client. @@ -201,11 +203,14 @@ func (d *Detector) Check(req spamcheck.Request) (spam bool, cr []spamcheck.Respo } if d.FirstMessageOnly || d.FirstMessagesCount > 0 { + ctx, cancel := d.ctxWithStoreTimeout() + defer cancel() + au := approved.UserInfo{Count: d.approvedUsers[req.UserID].Count + 1, UserID: req.UserID, UserName: req.UserName, Timestamp: time.Now()} d.approvedUsers[req.UserID] = au if d.userStorage != nil && !req.CheckOnly { - _ = d.userStorage.Write(au) // ignore error, failed to write to storage is not critical here + _ = d.userStorage.Write(ctx, au) // ignore error, failed to write to storage is not critical here } } return false, cr @@ -234,7 +239,11 @@ func (d *Detector) WithUserStorage(storage UserStorage) (count int, err error) { defer d.lock.Unlock() d.approvedUsers = make(map[string]approved.UserInfo) // reset approved users d.userStorage = storage - users, err := d.userStorage.Read() + + ctx, cancel := d.ctxWithStoreTimeout() + defer cancel() + + users, err := d.userStorage.Read(ctx) if err != nil { return 0, fmt.Errorf("failed to read approved users from storage: %w", err) } @@ -298,7 +307,9 @@ func (d *Detector) AddApprovedUser(user approved.UserInfo) error { } if d.userStorage != nil { - if err := d.userStorage.Write(user); err != nil { + ctx, cancel := d.ctxWithStoreTimeout() + defer cancel() + if err := d.userStorage.Write(ctx, user); err != nil { return fmt.Errorf("failed to write approved user %+v to storage: %w", user, err) } } @@ -311,7 +322,9 @@ func (d *Detector) RemoveApprovedUser(id string) error { defer d.lock.Unlock() delete(d.approvedUsers, id) if d.userStorage != nil { - if err := d.userStorage.Delete(id); err != nil { + ctx, cancel := d.ctxWithStoreTimeout() + defer cancel() + if err := d.userStorage.Delete(ctx, id); err != nil { return fmt.Errorf("failed to delete approved user %s from storage: %w", id, err) } } @@ -670,3 +683,10 @@ func (d *Detector) cleanText(text string) string { } return result.String() } + +func (d *Detector) ctxWithStoreTimeout() (context.Context, context.CancelFunc) { + if d.StorageTimeout == 0 { + return context.Background(), func() {} + } + return context.WithTimeout(context.Background(), d.StorageTimeout) +} diff --git a/lib/tgspam/detector_test.go b/lib/tgspam/detector_test.go index 24505bc0..f29e79b0 100644 --- a/lib/tgspam/detector_test.go +++ b/lib/tgspam/detector_test.go @@ -861,9 +861,11 @@ func TestDetector_FirstMessagesCount(t *testing.T) { func TestDetector_ApprovedUsers(t *testing.T) { mockUserStore := &mocks.UserStorageMock{ - ReadFunc: func() ([]approved.UserInfo, error) { return []approved.UserInfo{{UserID: "123"}, {UserID: "456"}}, nil }, - WriteFunc: func(au approved.UserInfo) error { return nil }, - DeleteFunc: func(id string) error { return nil }, + ReadFunc: func(context.Context) ([]approved.UserInfo, error) { + return []approved.UserInfo{{UserID: "123"}, {UserID: "456"}}, nil + }, + WriteFunc: func(_ context.Context, au approved.UserInfo) error { return nil }, + DeleteFunc: func(_ context.Context, id string) error { return nil }, } t.Run("load with storage", func(t *testing.T) { diff --git a/lib/tgspam/mocks/user_storage.go b/lib/tgspam/mocks/user_storage.go index c14818a5..416d92e9 100644 --- a/lib/tgspam/mocks/user_storage.go +++ b/lib/tgspam/mocks/user_storage.go @@ -4,6 +4,7 @@ package mocks import ( + "context" "github.com/umputun/tg-spam/lib/approved" "sync" ) @@ -14,13 +15,13 @@ import ( // // // make and configure a mocked tgspam.UserStorage // mockedUserStorage := &UserStorageMock{ -// DeleteFunc: func(id string) error { +// DeleteFunc: func(ctx context.Context, id string) error { // panic("mock out the Delete method") // }, -// ReadFunc: func() ([]approved.UserInfo, error) { +// ReadFunc: func(ctx context.Context) ([]approved.UserInfo, error) { // panic("mock out the Read method") // }, -// WriteFunc: func(au approved.UserInfo) error { +// WriteFunc: func(ctx context.Context, au approved.UserInfo) error { // panic("mock out the Write method") // }, // } @@ -31,26 +32,32 @@ import ( // } type UserStorageMock struct { // DeleteFunc mocks the Delete method. - DeleteFunc func(id string) error + DeleteFunc func(ctx context.Context, id string) error // ReadFunc mocks the Read method. - ReadFunc func() ([]approved.UserInfo, error) + ReadFunc func(ctx context.Context) ([]approved.UserInfo, error) // WriteFunc mocks the Write method. - WriteFunc func(au approved.UserInfo) error + WriteFunc func(ctx context.Context, au approved.UserInfo) error // calls tracks calls to the methods. calls struct { // Delete holds details about calls to the Delete method. Delete []struct { + // Ctx is the ctx argument value. + Ctx context.Context // ID is the id argument value. ID string } // Read holds details about calls to the Read method. Read []struct { + // Ctx is the ctx argument value. + Ctx context.Context } // Write holds details about calls to the Write method. Write []struct { + // Ctx is the ctx argument value. + Ctx context.Context // Au is the au argument value. Au approved.UserInfo } @@ -61,19 +68,21 @@ type UserStorageMock struct { } // Delete calls DeleteFunc. -func (mock *UserStorageMock) Delete(id string) error { +func (mock *UserStorageMock) Delete(ctx context.Context, id string) error { if mock.DeleteFunc == nil { panic("UserStorageMock.DeleteFunc: method is nil but UserStorage.Delete was just called") } callInfo := struct { - ID string + Ctx context.Context + ID string }{ - ID: id, + Ctx: ctx, + ID: id, } mock.lockDelete.Lock() mock.calls.Delete = append(mock.calls.Delete, callInfo) mock.lockDelete.Unlock() - return mock.DeleteFunc(id) + return mock.DeleteFunc(ctx, id) } // DeleteCalls gets all the calls that were made to Delete. @@ -81,10 +90,12 @@ func (mock *UserStorageMock) Delete(id string) error { // // len(mockedUserStorage.DeleteCalls()) func (mock *UserStorageMock) DeleteCalls() []struct { - ID string + Ctx context.Context + ID string } { var calls []struct { - ID string + Ctx context.Context + ID string } mock.lockDelete.RLock() calls = mock.calls.Delete @@ -100,16 +111,19 @@ func (mock *UserStorageMock) ResetDeleteCalls() { } // Read calls ReadFunc. -func (mock *UserStorageMock) Read() ([]approved.UserInfo, error) { +func (mock *UserStorageMock) Read(ctx context.Context) ([]approved.UserInfo, error) { if mock.ReadFunc == nil { panic("UserStorageMock.ReadFunc: method is nil but UserStorage.Read was just called") } callInfo := struct { - }{} + Ctx context.Context + }{ + Ctx: ctx, + } mock.lockRead.Lock() mock.calls.Read = append(mock.calls.Read, callInfo) mock.lockRead.Unlock() - return mock.ReadFunc() + return mock.ReadFunc(ctx) } // ReadCalls gets all the calls that were made to Read. @@ -117,8 +131,10 @@ func (mock *UserStorageMock) Read() ([]approved.UserInfo, error) { // // len(mockedUserStorage.ReadCalls()) func (mock *UserStorageMock) ReadCalls() []struct { + Ctx context.Context } { var calls []struct { + Ctx context.Context } mock.lockRead.RLock() calls = mock.calls.Read @@ -134,19 +150,21 @@ func (mock *UserStorageMock) ResetReadCalls() { } // Write calls WriteFunc. -func (mock *UserStorageMock) Write(au approved.UserInfo) error { +func (mock *UserStorageMock) Write(ctx context.Context, au approved.UserInfo) error { if mock.WriteFunc == nil { panic("UserStorageMock.WriteFunc: method is nil but UserStorage.Write was just called") } callInfo := struct { - Au approved.UserInfo + Ctx context.Context + Au approved.UserInfo }{ - Au: au, + Ctx: ctx, + Au: au, } mock.lockWrite.Lock() mock.calls.Write = append(mock.calls.Write, callInfo) mock.lockWrite.Unlock() - return mock.WriteFunc(au) + return mock.WriteFunc(ctx, au) } // WriteCalls gets all the calls that were made to Write. @@ -154,10 +172,12 @@ func (mock *UserStorageMock) Write(au approved.UserInfo) error { // // len(mockedUserStorage.WriteCalls()) func (mock *UserStorageMock) WriteCalls() []struct { - Au approved.UserInfo + Ctx context.Context + Au approved.UserInfo } { var calls []struct { - Au approved.UserInfo + Ctx context.Context + Au approved.UserInfo } mock.lockWrite.RLock() calls = mock.calls.Write