diff --git a/app/server/server.go b/app/server/server.go index dc6c41c3..f9a04ca6 100644 --- a/app/server/server.go +++ b/app/server/server.go @@ -39,7 +39,7 @@ type SpamWeb struct { unbanned struct { sync.RWMutex - users map[int64]time.Time + users map[string]time.Time // key is the combination of userID and message hash } } @@ -68,7 +68,7 @@ type Detector interface { // NewSpamWeb creates new server func NewSpamWeb(tbAPI TbAPI, detector Detector, params Config) (*SpamWeb, error) { res := SpamWeb{Config: params, TbAPI: tbAPI, detector: detector} - res.unbanned.users = make(map[int64]time.Time) + res.unbanned.users = make(map[string]time.Time) chatID, err := res.getChatID(params.TgGroup) if err != nil { return nil, fmt.Errorf("can't get chat ID for %s: %w", params.TgGroup, err) @@ -83,7 +83,7 @@ func (s *SpamWeb) Run(ctx context.Context) error { router.Use(rest.Recoverer(lgr.Default())) router.Use(middleware.Throttle(1000), middleware.Timeout(60*time.Second)) router.Use(rest.AppInfo("tg-spam", "umputun", s.Version), rest.Ping) - router.Use(tollbooth_chi.LimitHandler(tollbooth.NewLimiter(10, nil))) + router.Use(tollbooth_chi.LimitHandler(tollbooth.NewLimiter(50, nil))) router.Get("/unban", s.unbanHandler) @@ -115,6 +115,8 @@ func (s *SpamWeb) unbanHandler(w http.ResponseWriter, r *http.Request) { id := r.URL.Query().Get("user") token := r.URL.Query().Get("token") userID, err := strconv.ParseInt(id, 10, 64) + msg := r.URL.Query().Get("msg") + if err != nil { log.Printf("[WARN] failed to get user ID for %q, %v", id, err) resp := htmlResponse{ @@ -145,7 +147,7 @@ func (s *SpamWeb) unbanHandler(w http.ResponseWriter, r *http.Request) { isAlreadyUnbanned, tsPrevUnban := func() (bool, time.Time) { s.unbanned.RLock() defer s.unbanned.RUnlock() - ts, ok := s.unbanned.users[userID] + ts, ok := s.unbanned.users[s.unbanKey(userID, msg)] return ok, ts }() @@ -197,7 +199,7 @@ func (s *SpamWeb) unbanHandler(w http.ResponseWriter, r *http.Request) { } s.unbanned.Lock() - s.unbanned.users[userID] = time.Now() + s.unbanned.users[s.unbanKey(userID, msg)] = time.Now() s.unbanned.Unlock() s.sendHTML(w, resp) @@ -308,6 +310,11 @@ func (s *SpamWeb) decompressString(compressed string) (string, error) { return string(decoded), nil } +// unbanKey returns key for unbanned users map, converts the key to sha256 +func (s *SpamWeb) unbanKey(id int64, msg string) string { + return fmt.Sprintf("%x", sha256.Sum256([]byte(fmt.Sprintf("%d::%s", id, msg)))) +} + var msgTemplate = ` diff --git a/app/server/server_test.go b/app/server/server_test.go index 3469dedd..1c3661ae 100644 --- a/app/server/server_test.go +++ b/app/server/server_test.go @@ -2,7 +2,9 @@ package server import ( "context" + "crypto/sha256" "errors" + "fmt" "io" "net/http" "testing" @@ -231,6 +233,40 @@ func TestSpamWeb_Run(t *testing.T) { assert.Contains(t, string(body), "user 1239 already unbanned") }) + + t.Run("unban allowed with second attempt, but different msg", func(t *testing.T) { + mockAPI.ResetCalls() + mockDetector.ResetCalls() + req, err := http.NewRequest("GET", + "http://localhost:9900/unban?user=1239&token=e2b5356cfe79210553b4a0bc89310ea5961dc76e86046b07c61e479c9835623c&msg=123", http.NoBody) + require.NoError(t, err) + client := &http.Client{Timeout: 5 * time.Second} + resp, err := client.Do(req) + require.NoError(t, err) + defer resp.Body.Close() + assert.Equal(t, http.StatusOK, resp.StatusCode) + require.Equal(t, 1, len(mockAPI.RequestCalls())) + assert.Equal(t, int64(10), mockAPI.RequestCalls()[0].C.(tbapi.UnbanChatMemberConfig).ChatID) + assert.Equal(t, int64(1239), mockAPI.RequestCalls()[0].C.(tbapi.UnbanChatMemberConfig).UserID) + body, err := io.ReadAll(resp.Body) + require.NoError(t, err) + t.Logf("body: %s", body) + assert.Contains(t, string(body), "Success") + assert.Equal(t, "text/html", resp.Header.Get("Content-Type")) + assert.Equal(t, 0, len(mockDetector.UpdateHamCalls()), "no message, nothing to update") + + // second attempt + resp, err = client.Do(req) + require.NoError(t, err) + defer resp.Body.Close() + assert.Equal(t, http.StatusBadRequest, resp.StatusCode) + body, err = io.ReadAll(resp.Body) + require.NoError(t, err) + t.Logf("body: %s", body) + assert.Contains(t, string(body), "user 1239 already unbanned") + + }) + t.Run("unban allowed, matched token with msg", func(t *testing.T) { mockAPI.ResetCalls() mockDetector.ResetCalls() @@ -372,3 +408,39 @@ func TestSpamWeb_CompressAndDecompressString(t *testing.T) { }) } } + +func TestUnbanKey(t *testing.T) { + s := SpamWeb{} + + t.Run("GeneratesCorrectKey", func(t *testing.T) { + id := int64(123) + msg := "test message" + expectedKey := fmt.Sprintf("%x", sha256.Sum256([]byte(fmt.Sprintf("%d::%s", id, msg)))) + + key := s.unbanKey(id, msg) + + assert.Equal(t, expectedKey, key) + }) + + t.Run("GeneratesDifferentKeysForDifferentIDs", func(t *testing.T) { + id1 := int64(123) + id2 := int64(456) + msg := "test message" + + key1 := s.unbanKey(id1, msg) + key2 := s.unbanKey(id2, msg) + + assert.NotEqual(t, key1, key2) + }) + + t.Run("GeneratesDifferentKeysForDifferentMessages", func(t *testing.T) { + id := int64(123) + msg1 := "test message 1" + msg2 := "test message 2" + + key1 := s.unbanKey(id, msg1) + key2 := s.unbanKey(id, msg2) + + assert.NotEqual(t, key1, key2) + }) +}