diff --git a/app/main.go b/app/main.go index 823eff9e..780b99d2 100644 --- a/app/main.go +++ b/app/main.go @@ -363,15 +363,15 @@ func activateServer(ctx context.Context, opts options, sf *bot.SpamFilter, loc * } srv := webapi.Server{Config: webapi.Config{ - ListenAddr: opts.Server.ListenAddr, - Detector: sf.Detector, - SpamFilter: sf, - Locator: loc, - DetectedSpamReader: detectedSpamStore, - AuthPasswd: authPassswd, - Version: revision, - Dbg: opts.Dbg, - Settings: settings, + ListenAddr: opts.Server.ListenAddr, + Detector: sf.Detector, + SpamFilter: sf, + Locator: loc, + DetectedSpam: detectedSpamStore, + AuthPasswd: authPassswd, + Version: revision, + Dbg: opts.Dbg, + Settings: settings, }} go func() { diff --git a/app/storage/detected_spam.go b/app/storage/detected_spam.go index 8afabb10..95923bc5 100644 --- a/app/storage/detected_spam.go +++ b/app/storage/detected_spam.go @@ -4,6 +4,7 @@ import ( "encoding/json" "fmt" "log" + "strings" "time" "github.com/jmoiron/sqlx" @@ -20,10 +21,12 @@ type DetectedSpam struct { // DetectedSpamInfo represents information about a detected spam entry. type DetectedSpamInfo struct { + ID int64 `db:"id"` Text string `db:"text"` UserID int64 `db:"user_id"` UserName string `db:"user_name"` Timestamp time.Time `db:"timestamp"` + Added bool `db:"added"` // added to samples ChecksJSON string `db:"checks"` // Store as JSON Checks []spamcheck.Response `db:"-"` // Don't store in DB } @@ -31,17 +34,24 @@ type DetectedSpamInfo struct { // NewDetectedSpam creates a new DetectedSpam storage func NewDetectedSpam(db *sqlx.DB) (*DetectedSpam, error) { _, err := db.Exec(`CREATE TABLE IF NOT EXISTS detected_spam ( - id INTEGER PRIMARY KEY AUTOINCREMENT, - text TEXT, - user_id INTEGER, - user_name TEXT, - timestamp DATETIME DEFAULT CURRENT_TIMESTAMP, - checks TEXT - )`) + id INTEGER PRIMARY KEY AUTOINCREMENT, + text TEXT, + user_id INTEGER, + user_name TEXT, + timestamp DATETIME DEFAULT CURRENT_TIMESTAMP, + added BOOLEAN DEFAULT 0, + checks TEXT + )`) if err != nil { return nil, fmt.Errorf("failed to create detected_spam table: %w", err) } + _, err = db.Exec(`ALTER TABLE detected_spam ADD COLUMN added BOOLEAN DEFAULT 0`) + if err != nil { + if !strings.Contains(err.Error(), "duplicate column name") { + return nil, fmt.Errorf("failed to alter detected_spam table: %w", err) + } + } // add index on timestamp if _, err = db.Exec(`CREATE INDEX IF NOT EXISTS idx_detected_spam_timestamp ON detected_spam(timestamp)`); err != nil { return nil, fmt.Errorf("failed to create index on timestamp: %w", err) @@ -66,10 +76,19 @@ func (ds *DetectedSpam) Write(entry DetectedSpamInfo, checks []spamcheck.Respons return nil } +// SetAddedToSamplesFlag sets the added flag to true for the detected spam entry with the given id +func (ds *DetectedSpam) SetAddedToSamplesFlag(id int64) error { + query := `UPDATE detected_spam SET added = 1 WHERE id = ?` + if _, err := ds.db.Exec(query, id); err != nil { + return fmt.Errorf("failed to update added to samples flag: %w", err) + } + return nil +} + // Read returns all detected spam entries func (ds *DetectedSpam) Read() ([]DetectedSpamInfo, error) { var entries []DetectedSpamInfo - err := ds.db.Select(&entries, "SELECT text, user_id, user_name, timestamp, checks FROM detected_spam ORDER BY timestamp DESC LIMIT ?", maxDetectedSpamEntries) + err := ds.db.Select(&entries, "SELECT * FROM detected_spam ORDER BY timestamp DESC LIMIT ?", maxDetectedSpamEntries) if err != nil { return nil, fmt.Errorf("failed to get detected spam entries: %w", err) } diff --git a/app/storage/detected_spam_test.go b/app/storage/detected_spam_test.go index 26e258fd..0cdd6c5e 100644 --- a/app/storage/detected_spam_test.go +++ b/app/storage/detected_spam_test.go @@ -58,6 +58,44 @@ func TestDetectedSpam_Write(t *testing.T) { assert.Equal(t, 1, count) } +func TestSetAddedToSamplesFlag(t *testing.T) { + db, err := sqlx.Open("sqlite", ":memory:") + require.NoError(t, err) + defer db.Close() + + ds, err := NewDetectedSpam(db) + require.NoError(t, err) + + spamEntry := DetectedSpamInfo{ + Text: "spam message", + UserID: 1, + UserName: "Spammer", + Timestamp: time.Now(), + } + + checks := []spamcheck.Response{ + { + Name: "Check1", + Spam: true, + Details: "Details 1", + }, + } + + err = ds.Write(spamEntry, checks) + require.NoError(t, err) + var added bool + err = db.Get(&added, "SELECT added FROM detected_spam WHERE text = ?", spamEntry.Text) + require.NoError(t, err) + assert.False(t, added) + + err = ds.SetAddedToSamplesFlag(1) + require.NoError(t, err) + + err = db.Get(&added, "SELECT added FROM detected_spam WHERE text = ?", spamEntry.Text) + require.NoError(t, err) + assert.True(t, added) +} + func TestDetectedSpam_Read(t *testing.T) { db, err := sqlx.Open("sqlite", ":memory:") require.NoError(t, err) diff --git a/app/webapi/assets/components/heads.html b/app/webapi/assets/components/heads.html new file mode 100644 index 00000000..7e8329d8 --- /dev/null +++ b/app/webapi/assets/components/heads.html @@ -0,0 +1,4 @@ + + + + \ No newline at end of file diff --git a/app/webapi/assets/detected_spam.html b/app/webapi/assets/detected_spam.html index e9d2d259..bca892ef 100644 --- a/app/webapi/assets/detected_spam.html +++ b/app/webapi/assets/detected_spam.html @@ -2,10 +2,7 @@ Detected Spam - TG-Spam - - - - + {{template "heads.html"}} {{template "navbar.html"}} @@ -13,6 +10,7 @@
+

Detected Spam ({{.TotalDetectedSpam}})

@@ -26,6 +24,9 @@

Detected Spam ({{.TotalDetectedSpam}})

{{range .DetectedSpamEntries}} + {{$text := .Text}} + {{$id := .ID}} + {{$added := .Added}} @@ -33,10 +34,22 @@

Detected Spam ({{.TotalDetectedSpam}})

{{else}} diff --git a/app/webapi/assets/manage_samples.html b/app/webapi/assets/manage_samples.html index dbeeb7d4..59730af8 100644 --- a/app/webapi/assets/manage_samples.html +++ b/app/webapi/assets/manage_samples.html @@ -2,10 +2,7 @@ Manage Samples - TG-Spam - - - - + {{template "heads.html"}} {{template "navbar.html"}} diff --git a/app/webapi/assets/manage_users.html b/app/webapi/assets/manage_users.html index 59380d7f..15f920db 100644 --- a/app/webapi/assets/manage_users.html +++ b/app/webapi/assets/manage_users.html @@ -2,10 +2,7 @@ Manage Users - TG-Spam - - - - + {{template "heads.html"}} {{template "navbar.html"}} diff --git a/app/webapi/assets/settings.html b/app/webapi/assets/settings.html index 6ff2040a..407cd613 100644 --- a/app/webapi/assets/settings.html +++ b/app/webapi/assets/settings.html @@ -2,9 +2,7 @@ Settings - TG-Spam - - - + {{template "heads.html"}} diff --git a/app/webapi/mocks/detected_spam.go b/app/webapi/mocks/detected_spam.go index adbbbc5f..c3dc39ce 100644 --- a/app/webapi/mocks/detected_spam.go +++ b/app/webapi/mocks/detected_spam.go @@ -8,38 +8,50 @@ import ( "sync" ) -// DetectedSpamReaderMock is a mock implementation of webapi.DetectedSpamReader. +// DetectedSpamMock is a mock implementation of webapi.DetectedSpam. // -// func TestSomethingThatUsesDetectedSpamReader(t *testing.T) { +// func TestSomethingThatUsesDetectedSpam(t *testing.T) { // -// // make and configure a mocked webapi.DetectedSpamReader -// mockedDetectedSpamReader := &DetectedSpamReaderMock{ +// // make and configure a mocked webapi.DetectedSpam +// mockedDetectedSpam := &DetectedSpamMock{ // ReadFunc: func() ([]storage.DetectedSpamInfo, error) { // panic("mock out the Read method") // }, +// SetAddedToSamplesFlagFunc: func(id int64) error { +// panic("mock out the SetAddedToSamplesFlag method") +// }, // } // -// // use mockedDetectedSpamReader in code that requires webapi.DetectedSpamReader +// // use mockedDetectedSpam in code that requires webapi.DetectedSpam // // and then make assertions. // // } -type DetectedSpamReaderMock struct { +type DetectedSpamMock struct { // ReadFunc mocks the Read method. ReadFunc func() ([]storage.DetectedSpamInfo, error) + // SetAddedToSamplesFlagFunc mocks the SetAddedToSamplesFlag method. + SetAddedToSamplesFlagFunc func(id int64) error + // calls tracks calls to the methods. calls struct { // Read holds details about calls to the Read method. Read []struct { } + // SetAddedToSamplesFlag holds details about calls to the SetAddedToSamplesFlag method. + SetAddedToSamplesFlag []struct { + // ID is the id argument value. + ID int64 + } } - lockRead sync.RWMutex + lockRead sync.RWMutex + lockSetAddedToSamplesFlag sync.RWMutex } // Read calls ReadFunc. -func (mock *DetectedSpamReaderMock) Read() ([]storage.DetectedSpamInfo, error) { +func (mock *DetectedSpamMock) Read() ([]storage.DetectedSpamInfo, error) { if mock.ReadFunc == nil { - panic("DetectedSpamReaderMock.ReadFunc: method is nil but DetectedSpamReader.Read was just called") + panic("DetectedSpamMock.ReadFunc: method is nil but DetectedSpam.Read was just called") } callInfo := struct { }{} @@ -52,8 +64,8 @@ func (mock *DetectedSpamReaderMock) Read() ([]storage.DetectedSpamInfo, error) { // ReadCalls gets all the calls that were made to Read. // Check the length with: // -// len(mockedDetectedSpamReader.ReadCalls()) -func (mock *DetectedSpamReaderMock) ReadCalls() []struct { +// len(mockedDetectedSpam.ReadCalls()) +func (mock *DetectedSpamMock) ReadCalls() []struct { } { var calls []struct { } @@ -64,15 +76,58 @@ func (mock *DetectedSpamReaderMock) ReadCalls() []struct { } // ResetReadCalls reset all the calls that were made to Read. -func (mock *DetectedSpamReaderMock) ResetReadCalls() { +func (mock *DetectedSpamMock) ResetReadCalls() { mock.lockRead.Lock() mock.calls.Read = nil mock.lockRead.Unlock() } +// SetAddedToSamplesFlag calls SetAddedToSamplesFlagFunc. +func (mock *DetectedSpamMock) SetAddedToSamplesFlag(id int64) error { + if mock.SetAddedToSamplesFlagFunc == nil { + panic("DetectedSpamMock.SetAddedToSamplesFlagFunc: method is nil but DetectedSpam.SetAddedToSamplesFlag was just called") + } + callInfo := struct { + ID int64 + }{ + ID: id, + } + mock.lockSetAddedToSamplesFlag.Lock() + mock.calls.SetAddedToSamplesFlag = append(mock.calls.SetAddedToSamplesFlag, callInfo) + mock.lockSetAddedToSamplesFlag.Unlock() + return mock.SetAddedToSamplesFlagFunc(id) +} + +// SetAddedToSamplesFlagCalls gets all the calls that were made to SetAddedToSamplesFlag. +// Check the length with: +// +// len(mockedDetectedSpam.SetAddedToSamplesFlagCalls()) +func (mock *DetectedSpamMock) SetAddedToSamplesFlagCalls() []struct { + ID int64 +} { + var calls []struct { + ID int64 + } + mock.lockSetAddedToSamplesFlag.RLock() + calls = mock.calls.SetAddedToSamplesFlag + mock.lockSetAddedToSamplesFlag.RUnlock() + return calls +} + +// ResetSetAddedToSamplesFlagCalls reset all the calls that were made to SetAddedToSamplesFlag. +func (mock *DetectedSpamMock) ResetSetAddedToSamplesFlagCalls() { + mock.lockSetAddedToSamplesFlag.Lock() + mock.calls.SetAddedToSamplesFlag = nil + mock.lockSetAddedToSamplesFlag.Unlock() +} + // ResetCalls reset all the calls that were made to all mocked methods. -func (mock *DetectedSpamReaderMock) ResetCalls() { +func (mock *DetectedSpamMock) ResetCalls() { mock.lockRead.Lock() mock.calls.Read = nil mock.lockRead.Unlock() + + mock.lockSetAddedToSamplesFlag.Lock() + mock.calls.SetAddedToSamplesFlag = nil + mock.lockSetAddedToSamplesFlag.Unlock() } diff --git a/app/webapi/webapi.go b/app/webapi/webapi.go index 4149ef4b..29202a08 100644 --- a/app/webapi/webapi.go +++ b/app/webapi/webapi.go @@ -31,7 +31,7 @@ import ( //go:generate moq --out mocks/detector.go --pkg mocks --with-resets --skip-ensure . Detector //go:generate moq --out mocks/spam_filter.go --pkg mocks --with-resets --skip-ensure . SpamFilter //go:generate moq --out mocks/locator.go --pkg mocks --with-resets --skip-ensure . Locator -//go:generate moq --out mocks/detected_spam.go --pkg mocks --with-resets --skip-ensure . DetectedSpamReader +//go:generate moq --out mocks/detected_spam.go --pkg mocks --with-resets --skip-ensure . DetectedSpam //go:embed assets/* assets/components/* var templateFS embed.FS @@ -43,15 +43,15 @@ type Server struct { // Config defines server parameters type Config struct { - Version string // version to show in /ping - ListenAddr string // listen address - Detector Detector // spam detector - SpamFilter SpamFilter // spam filter (bot) - DetectedSpamReader DetectedSpamReader // detected spam reader from storage - Locator Locator // locator for user info - AuthPasswd string // basic auth password for user "tg-spam" - Dbg bool // debug mode - Settings Settings // application settings + Version string // version to show in /ping + ListenAddr string // listen address + Detector Detector // spam detector + SpamFilter SpamFilter // spam filter (bot) + DetectedSpam DetectedSpam // detected spam accessor + Locator Locator // locator for user info + AuthPasswd string // basic auth password for user "tg-spam" + Dbg bool // debug mode + Settings Settings // application settings } // Settings contains all application settings @@ -104,9 +104,10 @@ type Locator interface { UserNameByID(userID int64) string } -// DetectedSpamReader is a storage interface used to get detected spam messages. -type DetectedSpamReader interface { +// DetectedSpam is a storage interface used to get detected spam messages and set added flag. +type DetectedSpam interface { Read() ([]storage.DetectedSpamInfo, error) + SetAddedToSamplesFlag(id int64) error } // NewServer creates a new web API server. @@ -190,13 +191,14 @@ func (s *Server) routes(router *chi.Mux) *chi.Mux { router.Group(func(webUI chi.Router) { webUI.Use(s.authMiddleware(rest.BasicAuthWithPrompt("tg-spam", s.AuthPasswd))) - webUI.Get("/", s.htmlSpamCheckHandler) // serve template for webUI UI - webUI.Get("/manage_samples", s.htmlManageSamplesHandler) // serve manage samples page - webUI.Get("/manage_users", s.htmlManageUsersHandler) // serve manage users page - webUI.Get("/detected_spam", s.htmlDetectedSpamHandler) // serve detected spam page - webUI.Get("/list_settings", s.htmlSettingsHandler) // serve settings - webUI.Get("/styles.css", s.stylesHandler) // serve styles.css - webUI.Get("/logo.png", s.logoutHandler) // serve logo.png + webUI.Get("/", s.htmlSpamCheckHandler) // serve template for webUI UI + webUI.Get("/manage_samples", s.htmlManageSamplesHandler) // serve manage samples page + webUI.Get("/manage_users", s.htmlManageUsersHandler) // serve manage users page + webUI.Get("/detected_spam", s.htmlDetectedSpamHandler) // serve detected spam page + webUI.Get("/list_settings", s.htmlSettingsHandler) // serve settings + webUI.Get("/styles.css", s.stylesHandler) // serve styles.css + webUI.Get("/logo.png", s.logoutHandler) // serve logo.png + webUI.Post("/detected_spam/add", s.htmlAddDetectedSpamHandler) // add detected spam to samples }) return router @@ -452,7 +454,8 @@ func (s *Server) getApprovedUsersHandler(w http.ResponseWriter, _ *http.Request) // htmlSpamCheckHandler handles GET / request. // It returns rendered spam_check.html template with all the components. func (s *Server) htmlSpamCheckHandler(w http.ResponseWriter, _ *http.Request) { - tmpl, err := template.New("").ParseFS(templateFS, "assets/spam_check.html", "assets/components/navbar.html") + tmpl, err := template.New("").ParseFS(templateFS, + "assets/spam_check.html", "assets/components/heads.html", "assets/components/navbar.html") if err != nil { log.Printf("[WARN] can't load template: %v", err) http.Error(w, "Error loading template", http.StatusInternalServerError) @@ -497,7 +500,8 @@ func (s *Server) htmlManageSamplesHandler(w http.ResponseWriter, _ *http.Request // Parse the navbar and manage_samples templates tmpl, err := template.New("").ParseFS(templateFS, - "assets/manage_samples.html", "assets/components/navbar.html", "assets/components/samples_list.html") + "assets/manage_samples.html", "assets/components/heads.html", + "assets/components/navbar.html", "assets/components/samples_list.html") if err != nil { log.Printf("[WARN] failed to parse templates: %v", err) http.Error(w, "Internal Server Error", http.StatusInternalServerError) @@ -514,7 +518,7 @@ func (s *Server) htmlManageSamplesHandler(w http.ResponseWriter, _ *http.Request func (s *Server) htmlManageUsersHandler(w http.ResponseWriter, _ *http.Request) { tmpl, err := template.New("").ParseFS(templateFS, "assets/manage_users.html", - "assets/components/navbar.html", "assets/components/users_list.html") + "assets/components/heads.html", "assets/components/navbar.html", "assets/components/users_list.html") if err != nil { log.Printf("[WARN] can't load template: %v", err) http.Error(w, "Error loading template", http.StatusInternalServerError) @@ -539,14 +543,15 @@ func (s *Server) htmlManageUsersHandler(w http.ResponseWriter, _ *http.Request) } func (s *Server) htmlDetectedSpamHandler(w http.ResponseWriter, _ *http.Request) { - tmpl, err := template.New("").ParseFS(templateFS, "assets/detected_spam.html", "assets/components/navbar.html") + tmpl, err := template.New("").ParseFS(templateFS, + "assets/detected_spam.html", "assets/components/heads.html", "assets/components/navbar.html") if err != nil { log.Printf("[WARN] can't load template: %v", err) http.Error(w, "Error loading template", http.StatusInternalServerError) return } - ds, err := s.DetectedSpamReader.Read() + ds, err := s.DetectedSpam.Read() if err != nil { log.Printf("[ERROR] Failed to fetch detected spam: %v", err) http.Error(w, "Internal Server Error", http.StatusInternalServerError) @@ -568,8 +573,38 @@ func (s *Server) htmlDetectedSpamHandler(w http.ResponseWriter, _ *http.Request) } } +func (s *Server) htmlAddDetectedSpamHandler(w http.ResponseWriter, r *http.Request) { + reportErr := func(err error, _ int) { + w.Header().Set("HX-Retarget", "#error-message") + fmt.Fprintf(w, "
%s
", err) + } + msg := r.FormValue("msg") + + id, err := strconv.ParseInt(r.FormValue("id"), 10, 64) + if err != nil || msg == "" { + log.Printf("[WARN] bad request: %v", err) + reportErr(fmt.Errorf("bad request: %v", err), http.StatusBadRequest) + return + } + + if err := s.SpamFilter.UpdateSpam(msg); err != nil { + log.Printf("[WARN] failed to update spam samples: %v", err) + reportErr(fmt.Errorf("can't update spam samples: %v", err), http.StatusInternalServerError) + return + + } + if err := s.DetectedSpam.SetAddedToSamplesFlag(id); err != nil { + log.Printf("[WARN] failed to update detected spam: %v", err) + reportErr(fmt.Errorf("can't update detected spam: %v", err), http.StatusInternalServerError) + return + } + w.Header().Set("HX-Redirect", "/detected_spam") + _, _ = w.Write([]byte("redirecting...")) +} + func (s *Server) htmlSettingsHandler(w http.ResponseWriter, _ *http.Request) { - tmpl, err := template.New("").ParseFS(templateFS, "assets/settings.html", "assets/components/navbar.html") + tmpl, err := template.New("").ParseFS(templateFS, + "assets/settings.html", "assets/components/heads.html", "assets/components/navbar.html") if err != nil { log.Printf("[WARN] can't load template: %v", err) http.Error(w, "Error loading template", http.StatusInternalServerError) diff --git a/app/webapi/webapi_test.go b/app/webapi/webapi_test.go index e7c23f34..58e76edb 100644 --- a/app/webapi/webapi_test.go +++ b/app/webapi/webapi_test.go @@ -697,7 +697,7 @@ func TestServer_updateApprovedUsersHandler(t *testing.T) { func TestServer_htmlDetectedSpamHandler(t *testing.T) { calls := 0 - ds := &mocks.DetectedSpamReaderMock{ + ds := &mocks.DetectedSpamMock{ ReadFunc: func() ([]storage.DetectedSpamInfo, error) { calls++ if calls > 1 { @@ -719,7 +719,7 @@ func TestServer_htmlDetectedSpamHandler(t *testing.T) { }, nil }, } - server := NewServer(Config{DetectedSpamReader: ds}) + server := NewServer(Config{DetectedSpam: ds}) t.Run("successful rendering", func(t *testing.T) { req, err := http.NewRequest("GET", "/detected_spam", http.NoBody) @@ -747,6 +747,33 @@ func TestServer_htmlDetectedSpamHandler(t *testing.T) { }) } +func TestServer_htmlAddDetectedSpamHandler(t *testing.T) { + ds := &mocks.DetectedSpamMock{ + SetAddedToSamplesFlagFunc: func(id int64) error { + return nil + }, + } + sf := &mocks.SpamFilterMock{ + UpdateSpamFunc: func(msg string) error { + return nil + }, + } + server := NewServer(Config{DetectedSpam: ds, SpamFilter: sf}) + req, err := http.NewRequest("POST", "/detected_spam/add?id=123&msg=blah", http.NoBody) + require.NoError(t, err) + + rr := httptest.NewRecorder() + handler := http.HandlerFunc(server.htmlAddDetectedSpamHandler) + + handler.ServeHTTP(rr, req) + + assert.Equal(t, http.StatusOK, rr.Code) + assert.Equal(t, 1, len(ds.SetAddedToSamplesFlagCalls())) + assert.Equal(t, int64(123), ds.SetAddedToSamplesFlagCalls()[0].ID) + assert.Equal(t, 1, len(sf.UpdateSpamCalls())) + assert.Equal(t, "blah", sf.UpdateSpamCalls()[0].Msg) +} + func TestServer_GenerateRandomPassword(t *testing.T) { res1, err := GenerateRandomPassword(32) require.NoError(t, err)
{{.Timestamp.Format "2006-01-02 15:04:05"}} {{.UserID}} {{.Text}} {{range .Checks}} -
- {{.Name}}: {{.Details}} +
+
+ {{.Name}}: {{.Details}} +
+ {{if and (not .Spam) (not $added) (eq .Name "classifier")}} + + {{end}}
{{end}} +