Skip to content

Commit

Permalink
add gid and migration to locator tables
Browse files Browse the repository at this point in the history
  • Loading branch information
umputun committed Jan 4, 2025
1 parent 176e0a0 commit 873eb20
Show file tree
Hide file tree
Showing 3 changed files with 261 additions and 60 deletions.
2 changes: 1 addition & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -33,4 +33,4 @@ logs/
site/public/
_examples/simplechat/messages.db
*.loaded
var/
var/
154 changes: 108 additions & 46 deletions app/storage/locator.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,10 @@ import (
"encoding/json"
"fmt"
"log"
"strings"
"time"

"github.com/jmoiron/sqlx"
_ "modernc.org/sqlite" // sqlite driver loaded here

"github.com/umputun/tg-spam/lib/spamcheck"
Expand Down Expand Up @@ -38,44 +40,61 @@ type SpamData struct {
Checks []spamcheck.Response
}

// locatorSchema is the DDL for the locator tables. It defines messages (primary key: hash) and
// spam (primary key: user_id), both with a gid column that is NOT NULL and defaults to '',
// plus indexes on messages.user_id, messages.user_name, spam.time and on gid for both tables.
// Every statement uses IF NOT EXISTS.
// NOTE(review): neither primary key includes gid, while rows are written with INSERT OR REPLACE;
// a write from another gid with the same hash/user_id would presumably replace the existing row —
// confirm this is acceptable if one database is shared between several gids.
var locatorSchema = `
CREATE TABLE IF NOT EXISTS messages (
hash TEXT PRIMARY KEY,
gid TEXT NOT NULL DEFAULT '',
time TIMESTAMP,
chat_id INTEGER,
user_id INTEGER,
user_name TEXT,
msg_id INTEGER
);
CREATE TABLE IF NOT EXISTS spam (
user_id INTEGER PRIMARY KEY,
gid TEXT NOT NULL DEFAULT '',
time TIMESTAMP,
checks TEXT
);
CREATE INDEX IF NOT EXISTS idx_messages_user_id ON messages(user_id);
CREATE INDEX IF NOT EXISTS idx_messages_user_name ON messages(user_name);
CREATE INDEX IF NOT EXISTS idx_spam_time ON spam(time);
CREATE INDEX IF NOT EXISTS idx_messages_gid ON messages(gid);
CREATE INDEX IF NOT EXISTS idx_spam_gid ON spam(gid);
`

// NewLocator creates new Locator. ttl defines how long to keep messages in db, minSize defines the minimum number of messages to keep
func NewLocator(ctx context.Context, ttl time.Duration, minSize int, db *Engine) (*Locator, error) {
if db == nil {
return nil, fmt.Errorf("db connection is nil")
}

schema := `
CREATE TABLE IF NOT EXISTS messages (
hash TEXT PRIMARY KEY,
time TIMESTAMP,
chat_id INTEGER,
user_id INTEGER,
user_name TEXT,
msg_id INTEGER
);
CREATE TABLE IF NOT EXISTS spam (
user_id INTEGER PRIMARY KEY,
time TIMESTAMP,
checks TEXT
);
CREATE INDEX IF NOT EXISTS idx_messages_user_id ON messages(user_id);
CREATE INDEX IF NOT EXISTS idx_messages_user_name ON messages(user_name);
CREATE INDEX IF NOT EXISTS idx_spam_time ON spam(time);
`

// create schema in a single transaction
tx, err := db.Begin()
// first check if tables exist, we can't do this in a transaction because ALTER TABLE will fail
var exists int
err := db.GetContext(ctx, &exists, "SELECT COUNT(*) FROM sqlite_master WHERE type='table' AND name='messages'")
if err != nil {
return nil, fmt.Errorf("failed to start transaction: %w", err)
return nil, fmt.Errorf("failed to check for messages table existence: %w", err)
}
defer tx.Rollback()

if _, err = tx.ExecContext(ctx, schema); err != nil {
return nil, fmt.Errorf("failed to create schema: %w", err)
if exists == 0 { // tables do not exist, create them
tx, err := db.Begin()
if err != nil {
return nil, fmt.Errorf("failed to start transaction: %w", err)
}
defer tx.Rollback()

if _, err = tx.ExecContext(ctx, locatorSchema); err != nil {
return nil, fmt.Errorf("failed to create schema: %w", err)
}

if err = tx.Commit(); err != nil {
return nil, fmt.Errorf("failed to commit transaction: %w", err)
}
}

if err = tx.Commit(); err != nil {
return nil, fmt.Errorf("failed to commit transaction: %w", err)
// migrate tables
if err := migrateLocator(&db.DB, db.GID()); err != nil {
return nil, fmt.Errorf("failed to migrate locator: %w", err)
}

return &Locator{ttl: ttl, minSize: minSize, db: db, RWLocker: db.MakeLock()}, nil
Expand All @@ -94,11 +113,12 @@ func (l *Locator) AddMessage(ctx context.Context, msg string, chatID, userID int
hash := l.MsgHash(msg)
log.Printf("[DEBUG] add message to locator: %q, hash:%s, userID:%d, user name:%q, chatID:%d, msgID:%d",
msg, hash, userID, userName, chatID, msgID)
_, err := l.db.NamedExecContext(ctx, `INSERT OR REPLACE INTO messages (hash, time, chat_id, user_id, user_name, msg_id)
VALUES (:hash, :time, :chat_id, :user_id, :user_name, :msg_id)`,
_, err := l.db.NamedExecContext(ctx, `INSERT OR REPLACE INTO messages (hash, gid, time, chat_id, user_id, user_name, msg_id)
VALUES (:hash, :gid, :time, :chat_id, :user_id, :user_name, :msg_id)`,
struct {
MsgMeta
Hash string `db:"hash"`
GID string `db:"gid"`
}{
MsgMeta: MsgMeta{
Time: time.Now(),
Expand All @@ -108,6 +128,7 @@ func (l *Locator) AddMessage(ctx context.Context, msg string, chatID, userID int
MsgID: msgID,
},
Hash: hash,
GID: l.db.GID(),
})
if err != nil {
return fmt.Errorf("failed to insert message: %w", err)
Expand All @@ -124,10 +145,11 @@ func (l *Locator) AddSpam(ctx context.Context, userID int64, checks []spamcheck.
if err != nil {
return fmt.Errorf("failed to marshal checks: %w", err)
}
_, err = l.db.NamedExecContext(ctx, `INSERT OR REPLACE INTO spam (user_id, time, checks)
VALUES (:user_id, :time, :checks)`,
_, err = l.db.NamedExecContext(ctx, `INSERT OR REPLACE INTO spam (user_id, gid, time, checks)
VALUES (:user_id, :gid, :time, :checks)`,
map[string]interface{}{
"user_id": userID,
"gid": l.db.GID(),
"time": time.Now(),
"checks": string(checksStr),
})
Expand All @@ -137,58 +159,59 @@ func (l *Locator) AddSpam(ctx context.Context, userID int64, checks []spamcheck.
return l.cleanupSpam()
}

// Message returns the MsgMeta stored for the given msg; the row is looked up by l.MsgHash(msg).
// This allows matching a message from the admin chat (only text available) to the original message.
// The second result is false when the lookup returns an error; the error is logged at DEBUG level.
// NOTE(review): two lookup statements appear below (hash only, and hash plus gid = l.db.GID());
// this looks like both sides of a change rendered together — confirm which one is live.
func (l *Locator) Message(ctx context.Context, msg string) (MsgMeta, bool) {
l.RLock()
defer l.RUnlock()

var meta MsgMeta
hash := l.MsgHash(msg)
// lookup by hash only
err := l.db.GetContext(ctx, &meta, `SELECT time, chat_id, user_id, user_name, msg_id FROM messages WHERE hash = ?`, hash)
// lookup by hash, restricted to rows whose gid equals l.db.GID()
err := l.db.GetContext(ctx, &meta, `SELECT time, chat_id, user_id, user_name, msg_id
FROM messages WHERE hash = ? AND gid = ?`, hash, l.db.GID())
if err != nil {
log.Printf("[DEBUG] failed to find message by hash %q: %v", hash, err)
return MsgMeta{}, false
}
return meta, true
}

// UserNameByID returns the user_name of one messages row (LIMIT 1) with the given user id.
// Returns an empty string when the lookup returns an error; the error is logged at DEBUG level.
// NOTE(review): two lookup statements appear below (user_id only, and user_id plus gid = l.db.GID());
// this looks like both sides of a change rendered together — confirm which one is live.
func (l *Locator) UserNameByID(ctx context.Context, userID int64) string {
l.RLock()
defer l.RUnlock()

var userName string
// lookup by user_id only
err := l.db.GetContext(ctx, &userName, `SELECT user_name FROM messages WHERE user_id = ? LIMIT 1`, userID)
// lookup by user_id, restricted to rows whose gid equals l.db.GID()
err := l.db.GetContext(ctx, &userName, `SELECT user_name FROM messages WHERE user_id = ? AND gid = ? LIMIT 1`, userID, l.db.GID())
if err != nil {
log.Printf("[DEBUG] failed to find user name by id %d: %v", userID, err)
return ""
}
return userName
}

// UserIDByName returns the user_id of one messages row (LIMIT 1) with the given username.
// Returns 0 when the lookup returns an error; the error is logged at DEBUG level.
// NOTE(review): two lookup statements appear below (user_name only, and user_name plus gid = l.db.GID());
// this looks like both sides of a change rendered together — confirm which one is live.
func (l *Locator) UserIDByName(ctx context.Context, userName string) int64 {
l.RLock()
defer l.RUnlock()

var userID int64
// lookup by user_name only
err := l.db.GetContext(ctx, &userID, `SELECT user_id FROM messages WHERE user_name = ? LIMIT 1`, userName)
// lookup by user_name, restricted to rows whose gid equals l.db.GID()
err := l.db.GetContext(ctx, &userID, `SELECT user_id FROM messages WHERE user_name = ? AND gid = ? LIMIT 1`, userName, l.db.GID())
if err != nil {
log.Printf("[DEBUG] failed to find user id by name %q: %v", userName, err)
return 0
}
return userID
}

// Spam returns message SpamData for given msg
// Spam returns message SpamData for given msg within the same gid
func (l *Locator) Spam(ctx context.Context, userID int64) (SpamData, bool) {
l.RLock()
defer l.RUnlock()

var data SpamData
var checksStr string
err := l.db.QueryRowContext(ctx, `SELECT time, checks FROM spam WHERE user_id = ?`, userID).Scan(&data.Time, &checksStr)
err := l.db.QueryRowContext(ctx, `SELECT time, checks FROM spam WHERE user_id = ? AND gid = ?`,
userID, l.db.GID()).Scan(&data.Time, &checksStr)
if err != nil {
return SpamData{}, false
}
Expand All @@ -206,20 +229,19 @@ func (l *Locator) MsgHash(msg string) string {
}

// cleanupMessages removes old messages. Rows with time older than now-ttl are deleted, and only
// when the number of rows counted by the subquery exceeds minSize.
// The reason for minSize is to avoid removing messages on low-traffic chats where admin visits are rare.
// Returns a wrapped error if the DELETE fails.
// NOTE(review): two DELETE statements appear below (whole table, and restricted to gid = l.db.GID()
// for both the delete and the count); this looks like both sides of a change rendered together —
// confirm which one is live.
func (l *Locator) cleanupMessages(ctx context.Context) error {
// delete and count over the whole table
_, err := l.db.ExecContext(ctx, `DELETE FROM messages WHERE time < ? AND (SELECT COUNT(*) FROM messages) > ?`,
time.Now().Add(-l.ttl), l.minSize)
// delete and count restricted to rows whose gid equals l.db.GID()
_, err := l.db.ExecContext(ctx, `DELETE FROM messages WHERE time < ? AND gid = ? AND (SELECT COUNT(*) FROM messages WHERE gid = ?) > ?`,
time.Now().Add(-l.ttl), l.db.GID(), l.db.GID(), l.minSize)
if err != nil {
return fmt.Errorf("failed to cleanup messages: %w", err)
}
return nil
}

// cleanupSpam removes old spam data
// cleanupSpam removes old spam data within the same gid
func (l *Locator) cleanupSpam() error {
_, err := l.db.Exec(`DELETE FROM spam WHERE time < ? AND (SELECT COUNT(*) FROM spam) > ?`,
time.Now().Add(-l.ttl), l.minSize)
_, err := l.db.Exec(`DELETE FROM spam WHERE time < ? AND gid = ? AND (SELECT COUNT(*) FROM spam WHERE gid = ?) > ?`,
time.Now().Add(-l.ttl), l.db.GID(), l.db.GID(), l.minSize)
if err != nil {
return fmt.Errorf("failed to cleanup spam: %w", err)
}
Expand All @@ -234,3 +256,43 @@ func (m MsgMeta) String() string {
// String renders the spam data as "{time: <RFC3339 time>, checks: <checks>}",
// with the checks printed using the %+v verb.
func (s SpamData) String() string {
	ts := s.Time.Format(time.RFC3339)
	return fmt.Sprintf("{time: %s, checks: %+v}", ts, s.Checks)
}

// migrateLocator brings locator tables created with an older schema up to date with gid support.
// For both messages and spam it adds the gid column when it is missing, makes sure the gid index
// exists, and then sets gid on every row that still has an empty gid to the provided value.
//
// It is safe to run repeatedly: a "duplicate column name" error from ALTER TABLE means the column
// is already present and is ignored, and the indexes are created with IF NOT EXISTS.
// Any other failure is returned as a wrapped error.
func migrateLocator(db *sqlx.DB, gid string) error {
	// statements are spelled out per table, so no SQL text is assembled at runtime
	tables := []struct {
		name      string
		addColumn string
		addIndex  string
		backfill  string
	}{
		{
			name:      "messages",
			addColumn: `ALTER TABLE messages ADD COLUMN gid TEXT NOT NULL DEFAULT ''`,
			addIndex:  `CREATE INDEX IF NOT EXISTS idx_messages_gid ON messages(gid)`,
			backfill:  `UPDATE messages SET gid = ? WHERE gid = ''`,
		},
		{
			name:      "spam",
			addColumn: `ALTER TABLE spam ADD COLUMN gid TEXT NOT NULL DEFAULT ''`,
			addIndex:  `CREATE INDEX IF NOT EXISTS idx_spam_gid ON spam(gid)`,
			backfill:  `UPDATE spam SET gid = ? WHERE gid = ''`,
		},
	}

	// schema changes first, for all tables, so rows are touched only once every table has the column
	for _, t := range tables {
		// NOT NULL DEFAULT '' matches the gid column definition used for freshly created tables
		if _, err := db.Exec(t.addColumn); err != nil && !strings.Contains(err.Error(), "duplicate column name") {
			return fmt.Errorf("failed to alter %s table: %w", t.name, err)
		}
		// tables that predate gid were not created with the gid index, so it has to be added here;
		// without it every gid-filtered query on a migrated database scans the table
		if _, err := db.Exec(t.addIndex); err != nil {
			return fmt.Errorf("failed to create gid index for %s table: %w", t.name, err)
		}
	}

	// stamp existing records with the provided gid, keeping per-table counts for the log line
	affected := make([]int64, len(tables))
	for i, t := range tables {
		res, err := db.Exec(t.backfill, gid)
		if err != nil {
			return fmt.Errorf("failed to update gid for existing %s: %w", t.name, err)
		}
		count, err := res.RowsAffected()
		if err != nil {
			return fmt.Errorf("failed to get %s affected rows: %w", t.name, err)
		}
		affected[i] = count
	}

	log.Printf("[DEBUG] locator tables migrated, gid updated to %q, messages: %d, spam: %d",
		gid, affected[0], affected[1])
	return nil
}
Loading

0 comments on commit 873eb20

Please sign in to comment.