diff --git a/files/.gitignore b/files/.gitignore new file mode 100644 index 0000000..04dd13f --- /dev/null +++ b/files/.gitignore @@ -0,0 +1,8 @@ +# Virtual Environments +venv/ +.venv/ + +# Byte-compiled / optimized / DLL files +__pycache__/ +*.py[cod] +*$py.class diff --git a/files/requirements.txt b/files/requirements.txt new file mode 100644 index 0000000..b9e86f8 --- /dev/null +++ b/files/requirements.txt @@ -0,0 +1,3 @@ +psycopg2-binary>=2.9,<3.0 +PyYAML>=6.0,<7.0 +requests>=2.32,<3.0 diff --git a/files/test.py b/files/test.py new file mode 100644 index 0000000..ae8f770 --- /dev/null +++ b/files/test.py @@ -0,0 +1,58 @@ +import unittest +from unittest.mock import patch +from datetime import datetime, timedelta +import pathlib +import tempfile +from walle import NotificationHistory + +SLACK_NOTIFY_PERIOD_DAYS = 7 + + +class TestNotificationHistory(unittest.TestCase): + def setUp(self): + self.temp_file = tempfile.NamedTemporaryFile(delete=False) + self.record = NotificationHistory(self.temp_file.name) + + def tearDown(self): + pathlib.Path(self.temp_file.name).unlink(missing_ok=True) + + def test_contains_new_entry(self): + jwd = "unique_id_1" + self.assertFalse( + self.record.contains(jwd), "New entry should initially return False" + ) + self.assertTrue(self.record.contains(jwd), "Duplicate entry should return True") + + def test_contains_existing_entry(self): + jwd = "existing_id" + self.record._write_record(jwd) + self.assertTrue(self.record.contains(jwd), "Existing entry should return True") + + @patch("walle.SLACK_NOTIFY_PERIOD_DAYS", new=SLACK_NOTIFY_PERIOD_DAYS) + def test_truncate_old_records(self): + old_jwd = "old_entry" + recent_jwd = "recent_entry" + old_date = datetime.now() - timedelta(days=SLACK_NOTIFY_PERIOD_DAYS + 1) + recent_date = datetime.now() + + with open(self.temp_file.name, "a") as f: + f.write(f"{old_date.isoformat()}\t{old_jwd}\n") + f.write(f"{recent_date.isoformat()}\t{recent_jwd}\n") + + self.record._truncate_records() + self.assertFalse(self.record.contains(old_jwd), "Old entry should be purged") + self.assertTrue(self.record.contains(recent_jwd), "Recent entry should remain") + + def test_purge_invalid_records(self): + with open(self.temp_file.name, "w") as f: + f.write("invalid_date\tinvalid_path\n") + + with patch("walle.logger.warning") as mock_warning: + self.record._read_records() + mock_warning.assert_called() + + self.assertFalse(self.record._get_jwds(), "Invalid records should be purged") + + +if __name__ == "__main__": + unittest.main() diff --git a/files/walle.py b/files/walle.py index df5a23a..f2d032d 100644 --- a/files/walle.py +++ b/files/walle.py @@ -16,7 +16,9 @@ import sys import time import zlib -from typing import Dict, List +from dataclasses import dataclass +from datetime import datetime, timedelta +from typing import Dict, List, Union import galaxy_jwd import requests @@ -33,6 +35,9 @@ If you think your account was deleted due to an error, please contact """ ONLY_ONE_INSTANCE = "The other must be an instance of the Severity class" + +# Number of days before repeating slack alert for the same JWD +SLACK_NOTIFY_PERIOD_DAYS = 7 SLACK_URL = "https://slack.com/api/chat.postMessage" UserId = str @@ -46,6 +51,9 @@ ) logger = logging.getLogger(__name__) GXADMIN_PATH = os.getenv("GXADMIN_PATH", "/usr/local/bin/gxadmin") +NOTIFICATION_HISTORY_FILE = os.getenv( + "WALLE_NOTIFICATION_HISTORY_FILE", "/tmp/walle-notifications.txt" +) def convert_arg_to_byte(mb: str) -> int: @@ -56,6 +64,76 @@ def convert_arg_to_seconds(hours: str) -> float: return float(hours) * 60 * 60 +@dataclass +class Record: + date: str + jwd: str + + def __post_init__(self): + if not ( + isinstance(self.date, str) and isinstance(self.jwd, (str, pathlib.Path)) + ): + raise ValueError + self.jwd = str(self.jwd) + datetime.fromisoformat(self.date) # will raise ValueError if invalid + + +class NotificationHistory: + """Record of Slack notifications to avoid spamming users.""" + + def __init__(self, record_file: str) -> None: + self.record_file = pathlib.Path(record_file) + if not self.record_file.exists(): + self.record_file.touch() + self._truncate_records() + + def _get_jwds(self) -> List[str]: + return [record.jwd for record in self._read_records()] + + def _read_records(self) -> List[Record]: + try: + with open(self.record_file, "r") as f: + records = [ + Record(*line.strip().split("\t")) + for line in f.readlines() + if line.strip() + ] + except ValueError: + logger.warning( + f"Invalid records found in {self.record_file}. The" + " file will be purged. This may result in duplicate Slack" + " notifications." + ) + self._purge_records() + return [] + return records + + def _write_record(self, jwd: str) -> None: + with open(self.record_file, "a") as f: + f.write(f"{datetime.now()}\t{jwd}\n") + + def _purge_records(self) -> None: + self.record_file.unlink() + self.record_file.touch() + + def _truncate_records(self) -> None: + """Truncate older records.""" + records = self._read_records() + with open(self.record_file, "w") as f: + for record in records: + if datetime.fromisoformat(record.date) > datetime.now() - timedelta( + days=SLACK_NOTIFY_PERIOD_DAYS + ): + f.write(f"{record.date}\t{record.jwd}\n") + + def contains(self, jwd: Union[pathlib.Path, str]) -> bool: + jwd = str(jwd) + exists = jwd in self._get_jwds() + if not exists: + self._write_record(jwd) + return exists + + class Severity: def __init__(self, number: int, name: str): self.value = number @@ -87,6 +165,7 @@ def __ge__(self, other) -> bool: VALID_SEVERITIES = (Severity(0, "LOW"), Severity(1, "MEDIUM"), Severity(2, "HIGH")) +notification_history = NotificationHistory(NOTIFICATION_HISTORY_FILE) def convert_str_to_severity(test_level: str) -> Severity: @@ -406,6 +485,9 @@ def report_matching_malware(self): ) def post_slack_alert(self): + if notification_history.contains(self.job.jwd): + logger.debug("Skipping Slack notification - already posted for this JWD") + return msg = f""" :rotating_light: WALLE: *Malware detected* :rotating_light: