From f1a51707231e114781a95a71c81961e082dcd926 Mon Sep 17 00:00:00 2001 From: Cameron Hyde Date: Tue, 12 Nov 2024 11:47:02 +1000 Subject: [PATCH 01/14] Throttle Slack notifications with notifications record file --- files/walle.py | 60 ++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 60 insertions(+) diff --git a/files/walle.py b/files/walle.py index df5a23a..ec62f0b 100644 --- a/files/walle.py +++ b/files/walle.py @@ -16,6 +16,7 @@ import sys import time import zlib +from datetime import datetime, timedelta from typing import Dict, List import galaxy_jwd @@ -46,6 +47,8 @@ ) logger = logging.getLogger(__name__) GXADMIN_PATH = os.getenv("GXADMIN_PATH", "/usr/local/bin/gxadmin") +NOTIFICATION_RECORD_FILE = os.getenv("WALLE_NOTIFICATION_RECORD_FILE", + "/tmp/walle-notifications.txt") def convert_arg_to_byte(mb: str) -> int: @@ -56,6 +59,58 @@ def convert_arg_to_seconds(hours: str) -> float: return float(hours) * 60 * 60 +class NotificationRecord: + """Record of Slack notifications to avoid spamming users.""" + + def __init__(self, record_file: str) -> None: + self.record_file = record_file + self._truncate_records() + + def _get_jwds(self) -> List[str]: + return [ + line[1] for line in self._read_records() + ] + + def _read_records(self) -> List[str]: + with open(self.record_file, "r") as f: + records = [ + line.strip().split('\t') + for line in f.readlines() + if line.strip() + ] + return self._validate(records) + + def _validate(self, records: List[List[str]]) -> List[List[str]]: + try: + for datestr, path in records: + if not isinstance(datestr, str) and isinstance(path, str): + raise ValueError + datetime.strptime(datestr, '%Y-%m-%d') + except ValueError: + logger.warning( + f"Invalid records found in {self.record_file}. The" + " file will be purged. This may result in duplicate Slack" + " notifications.") + self._purge_records() + return [] + return records + + def _write_jwd(self, jwd: str): + with open(self.record_file, "a") as f: + f.write(f"{datetime.now()}\t{jwd}\n") + + def _truncate_records(self): + """Truncate older records.""" + records = self._read_records() + with open(self.record_file, "w") as f: + for datestr, jwd_path in records: + if datetime.strptime(datestr, '%Y-%m-%d') > datetime.now() - timedelta(days=7): + f.write(f"{datestr}\t{jwd_path}\n") + + def posted_for(self, jwd: str) -> bool: + return jwd in self._get_jwds() + + class Severity: def __init__(self, number: int, name: str): self.value = number @@ -87,6 +142,7 @@ def __ge__(self, other) -> bool: VALID_SEVERITIES = (Severity(0, "LOW"), Severity(1, "MEDIUM"), Severity(2, "HIGH")) +notification_record = NotificationRecord(NOTIFICATION_RECORD_FILE) def convert_str_to_severity(test_level: str) -> Severity: @@ -406,6 +462,10 @@ def report_matching_malware(self): ) def post_slack_alert(self): + if notification_record.posted_for(self.jb.jwd): + logger.debug( + "Skipping Slack notification - already posted for this JWD") + return msg = f""" :rotating_light: WALLE: *Malware detected* :rotating_light: From 03afb3f29b71303b8a165ec8f0eca920c156d8a9 Mon Sep 17 00:00:00 2001 From: Cameron Hyde Date: Tue, 12 Nov 2024 12:28:16 +1000 Subject: [PATCH 02/14] Touch record file on init --- files/walle.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/files/walle.py b/files/walle.py index ec62f0b..904d20d 100644 --- a/files/walle.py +++ b/files/walle.py @@ -63,7 +63,9 @@ class NotificationRecord: """Record of Slack notifications to avoid spamming users.""" def __init__(self, record_file: str) -> None: - self.record_file = record_file + self.record_file = pathlib.Path(record_file) + if not self.record_file.exists(): + self.record_file.touch() self._truncate_records() def _get_jwds(self) -> List[str]: From f7c9d6ed338fffc73c5c96ade07bcb1e9a1e4d17 Mon Sep 17 00:00:00 2001 From: Cameron Hyde Date: Tue, 12 Nov 2024 16:24:10 +1000 Subject: [PATCH 03/14] Add test and debug NotificationRecord --- files/test.py | 69 ++++++++++++++++++++++++++++++++++++++++++++++++++ files/walle.py | 21 ++++++++++++--- 2 files changed, 86 insertions(+), 4 deletions(-) create mode 100644 files/test.py diff --git a/files/test.py b/files/test.py new file mode 100644 index 0000000..86c568a --- /dev/null +++ b/files/test.py @@ -0,0 +1,69 @@ +import unittest +from unittest.mock import patch +from datetime import datetime, timedelta +import pathlib +import tempfile +from walle import NotificationRecord # Replace 'your_module' with the actual module name + +SLACK_NOTIFY_PERIOD_DAYS = 7 # Define a mock value for testing + + +class TestNotificationRecord(unittest.TestCase): + + def setUp(self): + # Create a temporary file for testing + self.temp_file = tempfile.NamedTemporaryFile(delete=False) + self.record = NotificationRecord(self.temp_file.name) + + def tearDown(self): + # Clean up the temporary file + pathlib.Path(self.temp_file.name).unlink(missing_ok=True) + + def test_posted_for_new_entry(self): + # Test if a new notification gets posted + jwd = "unique_id_1" + self.assertFalse(self.record.posted_for(jwd), "New entry should return False initially") + self.assertTrue(self.record.posted_for(jwd), "After posting, entry should return True") + + def test_posted_for_existing_entry(self): + # Write a notification to the record file + jwd = "existing_id" + with open(self.temp_file.name, "a") as f: + f.write(f"{datetime.now()}\t{jwd}\n") + self.assertTrue(self.record.posted_for(jwd), "Existing entry should return True") + + @patch("walle.SLACK_NOTIFY_PERIOD_DAYS", new=SLACK_NOTIFY_PERIOD_DAYS) + def test_truncate_old_records(self): + # Write an old and a recent notification + old_jwd = "old_entry" + recent_jwd = "recent_entry" + old_date = datetime.now() - timedelta(days=SLACK_NOTIFY_PERIOD_DAYS + 1) + recent_date = datetime.now() + + with open(self.temp_file.name, "a") as f: + f.write(f"{old_date.isoformat()}\t{old_jwd}\n") + f.write(f"{recent_date.isoformat()}\t{recent_jwd}\n") + + # Check truncation + self.record._truncate_records() + self.assertFalse(self.record.posted_for(old_jwd), "Old entry should be purged") + self.assertTrue(self.record.posted_for(recent_jwd), "Recent entry should remain") + + def test_purge_invalid_records(self): + # Write an invalid entry to the record file + with open(self.temp_file.name, "w") as f: + f.write("invalid_date\tinvalid_path\n") + + with patch("walle.logger.warning") as mock_warning: + self.record._read_records() + mock_warning.assert_called_once_with( + f"Invalid records found in {self.temp_file.name}. The" + " file will be purged. This may result in duplicate Slack" + " notifications." + ) + + # Check that file is purged + self.assertFalse(self.record._get_jwds(), "Invalid records should be purged") + +if __name__ == "__main__": + unittest.main() diff --git a/files/walle.py b/files/walle.py index 904d20d..6de7980 100644 --- a/files/walle.py +++ b/files/walle.py @@ -34,6 +34,9 @@ If you think your account was deleted due to an error, please contact """ ONLY_ONE_INSTANCE = "The other must be an instance of the Severity class" + +# Number of days before repeating slack alert for the same JWD +SLACK_NOTIFY_PERIOD_DAYS = 7 SLACK_URL = "https://slack.com/api/chat.postMessage" UserId = str @@ -87,7 +90,7 @@ def _validate(self, records: List[List[str]]) -> List[List[str]]: for datestr, path in records: if not isinstance(datestr, str) and isinstance(path, str): raise ValueError - datetime.strptime(datestr, '%Y-%m-%d') + datetime.fromisoformat(datestr) except ValueError: logger.warning( f"Invalid records found in {self.record_file}. The" @@ -101,16 +104,26 @@ def _write_jwd(self, jwd: str): with open(self.record_file, "a") as f: f.write(f"{datetime.now()}\t{jwd}\n") + def _purge_records(self): + self.record_file.unlink() + self.record_file.touch() + def _truncate_records(self): """Truncate older records.""" records = self._read_records() with open(self.record_file, "w") as f: for datestr, jwd_path in records: - if datetime.strptime(datestr, '%Y-%m-%d') > datetime.now() - timedelta(days=7): + if ( + datetime.fromisoformat(datestr) + > datetime.now() - timedelta(days=SLACK_NOTIFY_PERIOD_DAYS) + ): f.write(f"{datestr}\t{jwd_path}\n") def posted_for(self, jwd: str) -> bool: - return jwd in self._get_jwds() + exists = str(jwd) in self._get_jwds() + if not exists: + self._write_jwd(jwd) + return exists class Severity: @@ -464,7 +477,7 @@ def report_matching_malware(self): ) def post_slack_alert(self): - if notification_record.posted_for(self.jb.jwd): + if notification_record.posted_for(self.job.jwd): logger.debug( "Skipping Slack notification - already posted for this JWD") return From 2d2db8c45e81ea2f7fe19af6d599232b5ead2cd6 Mon Sep 17 00:00:00 2001 From: Cameron Hyde Date: Tue, 12 Nov 2024 16:34:22 +1000 Subject: [PATCH 04/14] NotificationRecord -> NotificationHistory --- files/test.py | 47 +++++++++++++++++++++-------------------------- files/walle.py | 12 ++++++------ 2 files changed, 27 insertions(+), 32 deletions(-) diff --git a/files/test.py b/files/test.py index 86c568a..d8766bf 100644 --- a/files/test.py +++ b/files/test.py @@ -3,67 +3,62 @@ from datetime import datetime, timedelta import pathlib import tempfile -from walle import NotificationRecord # Replace 'your_module' with the actual module name +from walle import NotificationHistory -SLACK_NOTIFY_PERIOD_DAYS = 7 # Define a mock value for testing +SLACK_NOTIFY_PERIOD_DAYS = 7 -class TestNotificationRecord(unittest.TestCase): +class TestNotificationHistory(unittest.TestCase): def setUp(self): - # Create a temporary file for testing self.temp_file = tempfile.NamedTemporaryFile(delete=False) - self.record = NotificationRecord(self.temp_file.name) + self.record = NotificationHistory(self.temp_file.name) def tearDown(self): - # Clean up the temporary file pathlib.Path(self.temp_file.name).unlink(missing_ok=True) - def test_posted_for_new_entry(self): - # Test if a new notification gets posted + def test_contains_new_entry(self): jwd = "unique_id_1" - self.assertFalse(self.record.posted_for(jwd), "New entry should return False initially") - self.assertTrue(self.record.posted_for(jwd), "After posting, entry should return True") + self.assertFalse(self.record.contains(jwd), + "New entry should initially return False") + self.assertTrue(self.record.contains(jwd), + "Duplicate entry should return True") - def test_posted_for_existing_entry(self): - # Write a notification to the record file + def test_contains_existing_entry(self): jwd = "existing_id" with open(self.temp_file.name, "a") as f: f.write(f"{datetime.now()}\t{jwd}\n") - self.assertTrue(self.record.posted_for(jwd), "Existing entry should return True") + self.assertTrue(self.record.contains(jwd), + "Existing entry should return True") @patch("walle.SLACK_NOTIFY_PERIOD_DAYS", new=SLACK_NOTIFY_PERIOD_DAYS) def test_truncate_old_records(self): - # Write an old and a recent notification old_jwd = "old_entry" recent_jwd = "recent_entry" - old_date = datetime.now() - timedelta(days=SLACK_NOTIFY_PERIOD_DAYS + 1) + old_date = datetime.now() - timedelta( + days=SLACK_NOTIFY_PERIOD_DAYS + 1) recent_date = datetime.now() with open(self.temp_file.name, "a") as f: f.write(f"{old_date.isoformat()}\t{old_jwd}\n") f.write(f"{recent_date.isoformat()}\t{recent_jwd}\n") - # Check truncation self.record._truncate_records() - self.assertFalse(self.record.posted_for(old_jwd), "Old entry should be purged") - self.assertTrue(self.record.posted_for(recent_jwd), "Recent entry should remain") + self.assertFalse(self.record.contains(old_jwd), + "Old entry should be purged") + self.assertTrue(self.record.contains(recent_jwd), + "Recent entry should remain") def test_purge_invalid_records(self): - # Write an invalid entry to the record file with open(self.temp_file.name, "w") as f: f.write("invalid_date\tinvalid_path\n") with patch("walle.logger.warning") as mock_warning: self.record._read_records() - mock_warning.assert_called_once_with( - f"Invalid records found in {self.temp_file.name}. The" - " file will be purged. This may result in duplicate Slack" - " notifications." - ) + mock_warning.assert_called() - # Check that file is purged - self.assertFalse(self.record._get_jwds(), "Invalid records should be purged") + self.assertFalse(self.record._get_jwds(), + "Invalid records should be purged") if __name__ == "__main__": unittest.main() diff --git a/files/walle.py b/files/walle.py index 6de7980..0620296 100644 --- a/files/walle.py +++ b/files/walle.py @@ -50,8 +50,8 @@ ) logger = logging.getLogger(__name__) GXADMIN_PATH = os.getenv("GXADMIN_PATH", "/usr/local/bin/gxadmin") -NOTIFICATION_RECORD_FILE = os.getenv("WALLE_NOTIFICATION_RECORD_FILE", - "/tmp/walle-notifications.txt") +NOTIFICATION_HISTORY_FILE = os.getenv("WALLE_NOTIFICATION_HISTORY_FILE", + "/tmp/walle-notifications.txt") def convert_arg_to_byte(mb: str) -> int: @@ -62,7 +62,7 @@ def convert_arg_to_seconds(hours: str) -> float: return float(hours) * 60 * 60 -class NotificationRecord: +class NotificationHistory: """Record of Slack notifications to avoid spamming users.""" def __init__(self, record_file: str) -> None: @@ -119,7 +119,7 @@ def _truncate_records(self): ): f.write(f"{datestr}\t{jwd_path}\n") - def posted_for(self, jwd: str) -> bool: + def contains(self, jwd: str) -> bool: exists = str(jwd) in self._get_jwds() if not exists: self._write_jwd(jwd) @@ -157,7 +157,7 @@ def __ge__(self, other) -> bool: VALID_SEVERITIES = (Severity(0, "LOW"), Severity(1, "MEDIUM"), Severity(2, "HIGH")) -notification_record = NotificationRecord(NOTIFICATION_RECORD_FILE) +notification_history = NotificationHistory(NOTIFICATION_HISTORY_FILE) def convert_str_to_severity(test_level: str) -> Severity: @@ -477,7 +477,7 @@ def report_matching_malware(self): ) def post_slack_alert(self): - if notification_record.posted_for(self.job.jwd): + if notification_history.contains(self.job.jwd): logger.debug( "Skipping Slack notification - already posted for this JWD") return From 727a9746336458b4c9342d60f4d79caeacba736c Mon Sep 17 00:00:00 2001 From: Cameron Hyde Date: Tue, 12 Nov 2024 16:34:29 +1000 Subject: [PATCH 05/14] Add requirements.txt --- files/requirements.txt | 3 +++ 1 file changed, 3 insertions(+) create mode 100644 files/requirements.txt diff --git a/files/requirements.txt b/files/requirements.txt new file mode 100644 index 0000000..f17ae1d --- /dev/null +++ b/files/requirements.txt @@ -0,0 +1,3 @@ +psycopg2-binary +requests +pyyaml From 200d9e9bf08a83a9d3462280d20e644cbfb2d7ce Mon Sep 17 00:00:00 2001 From: Cameron Hyde Date: Tue, 12 Nov 2024 16:50:42 +1000 Subject: [PATCH 06/14] Linting --- files/test.py | 24 ++++++++++-------------- files/walle.py | 30 +++++++++++++----------------- 2 files changed, 23 insertions(+), 31 deletions(-) diff --git a/files/test.py b/files/test.py index d8766bf..e3c410c 100644 --- a/files/test.py +++ b/files/test.py @@ -19,24 +19,22 @@ def tearDown(self): def test_contains_new_entry(self): jwd = "unique_id_1" - self.assertFalse(self.record.contains(jwd), - "New entry should initially return False") - self.assertTrue(self.record.contains(jwd), - "Duplicate entry should return True") + self.assertFalse( + self.record.contains(jwd), "New entry should initially return False" + ) + self.assertTrue(self.record.contains(jwd), "Duplicate entry should return True") def test_contains_existing_entry(self): jwd = "existing_id" with open(self.temp_file.name, "a") as f: f.write(f"{datetime.now()}\t{jwd}\n") - self.assertTrue(self.record.contains(jwd), - "Existing entry should return True") + self.assertTrue(self.record.contains(jwd), "Existing entry should return True") @patch("walle.SLACK_NOTIFY_PERIOD_DAYS", new=SLACK_NOTIFY_PERIOD_DAYS) def test_truncate_old_records(self): old_jwd = "old_entry" recent_jwd = "recent_entry" - old_date = datetime.now() - timedelta( - days=SLACK_NOTIFY_PERIOD_DAYS + 1) + old_date = datetime.now() - timedelta(days=SLACK_NOTIFY_PERIOD_DAYS + 1) recent_date = datetime.now() with open(self.temp_file.name, "a") as f: @@ -44,10 +42,8 @@ def test_truncate_old_records(self): f.write(f"{recent_date.isoformat()}\t{recent_jwd}\n") self.record._truncate_records() - self.assertFalse(self.record.contains(old_jwd), - "Old entry should be purged") - self.assertTrue(self.record.contains(recent_jwd), - "Recent entry should remain") + self.assertFalse(self.record.contains(old_jwd), "Old entry should be purged") + self.assertTrue(self.record.contains(recent_jwd), "Recent entry should remain") def test_purge_invalid_records(self): with open(self.temp_file.name, "w") as f: @@ -57,8 +53,8 @@ def test_purge_invalid_records(self): self.record._read_records() mock_warning.assert_called() - self.assertFalse(self.record._get_jwds(), - "Invalid records should be purged") + self.assertFalse(self.record._get_jwds(), "Invalid records should be purged") + if __name__ == "__main__": unittest.main() diff --git a/files/walle.py b/files/walle.py index 0620296..d60ba11 100644 --- a/files/walle.py +++ b/files/walle.py @@ -50,8 +50,9 @@ ) logger = logging.getLogger(__name__) GXADMIN_PATH = os.getenv("GXADMIN_PATH", "/usr/local/bin/gxadmin") -NOTIFICATION_HISTORY_FILE = os.getenv("WALLE_NOTIFICATION_HISTORY_FILE", - "/tmp/walle-notifications.txt") +NOTIFICATION_HISTORY_FILE = os.getenv( + "WALLE_NOTIFICATION_HISTORY_FILE", "/tmp/walle-notifications.txt" +) def convert_arg_to_byte(mb: str) -> int: @@ -72,16 +73,12 @@ def __init__(self, record_file: str) -> None: self._truncate_records() def _get_jwds(self) -> List[str]: - return [ - line[1] for line in self._read_records() - ] + return [line[1] for line in self._read_records()] def _read_records(self) -> List[str]: with open(self.record_file, "r") as f: records = [ - line.strip().split('\t') - for line in f.readlines() - if line.strip() + line.strip().split("\t") for line in f.readlines() if line.strip() ] return self._validate(records) @@ -95,27 +92,27 @@ def _validate(self, records: List[List[str]]) -> List[List[str]]: logger.warning( f"Invalid records found in {self.record_file}. The" " file will be purged. This may result in duplicate Slack" - " notifications.") + " notifications." + ) self._purge_records() return [] return records - def _write_jwd(self, jwd: str): + def _write_jwd(self, jwd: str) -> None: with open(self.record_file, "a") as f: f.write(f"{datetime.now()}\t{jwd}\n") - def _purge_records(self): + def _purge_records(self) -> None: self.record_file.unlink() self.record_file.touch() - def _truncate_records(self): + def _truncate_records(self) -> None: """Truncate older records.""" records = self._read_records() with open(self.record_file, "w") as f: for datestr, jwd_path in records: - if ( - datetime.fromisoformat(datestr) - > datetime.now() - timedelta(days=SLACK_NOTIFY_PERIOD_DAYS) + if datetime.fromisoformat(datestr) > datetime.now() - timedelta( + days=SLACK_NOTIFY_PERIOD_DAYS ): f.write(f"{datestr}\t{jwd_path}\n") @@ -478,8 +475,7 @@ def report_matching_malware(self): def post_slack_alert(self): if notification_history.contains(self.job.jwd): - logger.debug( - "Skipping Slack notification - already posted for this JWD") + logger.debug("Skipping Slack notification - already posted for this JWD") return msg = f""" :rotating_light: WALLE: *Malware detected* :rotating_light: From cd1b6f3fca5f2472901993c486701892170e175d Mon Sep 17 00:00:00 2001 From: Cameron Hyde Date: Tue, 12 Nov 2024 16:54:07 +1000 Subject: [PATCH 07/14] Fix type hinting --- files/walle.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/files/walle.py b/files/walle.py index d60ba11..6a31783 100644 --- a/files/walle.py +++ b/files/walle.py @@ -17,7 +17,7 @@ import time import zlib from datetime import datetime, timedelta -from typing import Dict, List +from typing import Dict, List, Union import galaxy_jwd import requests @@ -75,7 +75,7 @@ def __init__(self, record_file: str) -> None: def _get_jwds(self) -> List[str]: return [line[1] for line in self._read_records()] - def _read_records(self) -> List[str]: + def _read_records(self) -> List[List[str]]: with open(self.record_file, "r") as f: records = [ line.strip().split("\t") for line in f.readlines() if line.strip() @@ -116,7 +116,7 @@ def _truncate_records(self) -> None: ): f.write(f"{datestr}\t{jwd_path}\n") - def contains(self, jwd: str) -> bool: + def contains(self, jwd: Union[pathlib.Path, str]) -> bool: exists = str(jwd) in self._get_jwds() if not exists: self._write_jwd(jwd) From b2af809696a031050150542b7bcc3d91579edd95 Mon Sep 17 00:00:00 2001 From: Cameron Hyde Date: Tue, 12 Nov 2024 18:37:03 +1000 Subject: [PATCH 08/14] Fix type hinting --- files/walle.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/files/walle.py b/files/walle.py index 6a31783..52e2609 100644 --- a/files/walle.py +++ b/files/walle.py @@ -117,7 +117,8 @@ def _truncate_records(self) -> None: f.write(f"{datestr}\t{jwd_path}\n") def contains(self, jwd: Union[pathlib.Path, str]) -> bool: - exists = str(jwd) in self._get_jwds() + jwd = str(jwd) + exists = jwd in self._get_jwds() if not exists: self._write_jwd(jwd) return exists From f0e11a3a3ced87cee42d5a91356a049eb368212a Mon Sep 17 00:00:00 2001 From: Cameron Hyde Date: Tue, 12 Nov 2024 18:38:48 +1000 Subject: [PATCH 09/14] Fix linting --- files/test.py | 1 - 1 file changed, 1 deletion(-) diff --git a/files/test.py b/files/test.py index e3c410c..13496e2 100644 --- a/files/test.py +++ b/files/test.py @@ -9,7 +9,6 @@ class TestNotificationHistory(unittest.TestCase): - def setUp(self): self.temp_file = tempfile.NamedTemporaryFile(delete=False) self.record = NotificationHistory(self.temp_file.name) From 46b884a7510bf1162c666b952b56a84527c74fb8 Mon Sep 17 00:00:00 2001 From: Cameron Hyde Date: Thu, 14 Nov 2024 06:33:03 +1000 Subject: [PATCH 10/14] Add gitignore --- files/.gitignore | 8 ++++++++ 1 file changed, 8 insertions(+) create mode 100644 files/.gitignore diff --git a/files/.gitignore b/files/.gitignore new file mode 100644 index 0000000..04dd13f --- /dev/null +++ b/files/.gitignore @@ -0,0 +1,8 @@ +# Virtual Environments +venv/ +.venv/ + +# Byte-compiled / optimized / DLL files +__pycache__/ +*.py[cod] +*$py.class From ee90b0b0c1bfb60ba5ff787b35f8b20875a65028 Mon Sep 17 00:00:00 2001 From: Cameron Hyde Date: Thu, 14 Nov 2024 06:34:13 +1000 Subject: [PATCH 11/14] Refactor for better type hinting --- files/test.py | 3 +-- files/walle.py | 47 ++++++++++++++++++++++++++++------------------- 2 files changed, 29 insertions(+), 21 deletions(-) diff --git a/files/test.py b/files/test.py index 13496e2..ae8f770 100644 --- a/files/test.py +++ b/files/test.py @@ -25,8 +25,7 @@ def test_contains_new_entry(self): def test_contains_existing_entry(self): jwd = "existing_id" - with open(self.temp_file.name, "a") as f: - f.write(f"{datetime.now()}\t{jwd}\n") + self.record._write_record(jwd) self.assertTrue(self.record.contains(jwd), "Existing entry should return True") @patch("walle.SLACK_NOTIFY_PERIOD_DAYS", new=SLACK_NOTIFY_PERIOD_DAYS) diff --git a/files/walle.py b/files/walle.py index 52e2609..e4cca38 100644 --- a/files/walle.py +++ b/files/walle.py @@ -16,6 +16,7 @@ import sys import time import zlib +from dataclasses import dataclass from datetime import datetime, timedelta from typing import Dict, List, Union @@ -63,6 +64,19 @@ def convert_arg_to_seconds(hours: str) -> float: return float(hours) * 60 * 60 +@dataclass +class Record: + date: str + jwd: Union[str, pathlib.Path] + + def __post_init__(self): + if not ( + isinstance(self.date, str) and isinstance(self.jwd, (str, pathlib.Path)) + ): + raise ValueError + datetime.fromisoformat(self.date) # will raise ValueError if invalid + + class NotificationHistory: """Record of Slack notifications to avoid spamming users.""" @@ -72,22 +86,17 @@ def __init__(self, record_file: str) -> None: self.record_file.touch() self._truncate_records() - def _get_jwds(self) -> List[str]: - return [line[1] for line in self._read_records()] - - def _read_records(self) -> List[List[str]]: - with open(self.record_file, "r") as f: - records = [ - line.strip().split("\t") for line in f.readlines() if line.strip() - ] - return self._validate(records) + def _get_jwds(self) -> Record: + return [record.jwd for record in self._read_records()] - def _validate(self, records: List[List[str]]) -> List[List[str]]: + def _read_records(self) -> List[Record]: try: - for datestr, path in records: - if not isinstance(datestr, str) and isinstance(path, str): - raise ValueError - datetime.fromisoformat(datestr) + with open(self.record_file, "r") as f: + records = [ + Record(*line.strip().split("\t")) + for line in f.readlines() + if line.strip() + ] except ValueError: logger.warning( f"Invalid records found in {self.record_file}. The" @@ -98,7 +107,7 @@ def _validate(self, records: List[List[str]]) -> List[List[str]]: return [] return records - def _write_jwd(self, jwd: str) -> None: + def _write_record(self, jwd: str) -> None: with open(self.record_file, "a") as f: f.write(f"{datetime.now()}\t{jwd}\n") @@ -110,17 +119,17 @@ def _truncate_records(self) -> None: """Truncate older records.""" records = self._read_records() with open(self.record_file, "w") as f: - for datestr, jwd_path in records: - if datetime.fromisoformat(datestr) > datetime.now() - timedelta( + for record in records: + if datetime.fromisoformat(record.date) > datetime.now() - timedelta( days=SLACK_NOTIFY_PERIOD_DAYS ): - f.write(f"{datestr}\t{jwd_path}\n") + f.write(f"{record.date}\t{record.jwd}\n") def contains(self, jwd: Union[pathlib.Path, str]) -> bool: jwd = str(jwd) exists = jwd in self._get_jwds() if not exists: - self._write_jwd(jwd) + self._write_record(jwd) return exists From 34f3d851ee85268c3d534a52ac7afd847cd8add3 Mon Sep 17 00:00:00 2001 From: Cameron Hyde Date: Thu, 14 Nov 2024 06:37:54 +1000 Subject: [PATCH 12/14] Fix type hint --- files/walle.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/files/walle.py b/files/walle.py index e4cca38..bb05a46 100644 --- a/files/walle.py +++ b/files/walle.py @@ -86,7 +86,7 @@ def __init__(self, record_file: str) -> None: self.record_file.touch() self._truncate_records() - def _get_jwds(self) -> Record: + def _get_jwds(self) -> List[str]: return [record.jwd for record in self._read_records()] def _read_records(self) -> List[Record]: From 03e088bcf606ba0ad764dd1205e668e3f1060879 Mon Sep 17 00:00:00 2001 From: Cameron Hyde Date: Thu, 14 Nov 2024 06:40:20 +1000 Subject: [PATCH 13/14] Fix type hint --- files/walle.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/files/walle.py b/files/walle.py index bb05a46..f2d032d 100644 --- a/files/walle.py +++ b/files/walle.py @@ -67,13 +67,14 @@ def convert_arg_to_seconds(hours: str) -> float: @dataclass class Record: date: str - jwd: Union[str, pathlib.Path] + jwd: str def __post_init__(self): if not ( isinstance(self.date, str) and isinstance(self.jwd, (str, pathlib.Path)) ): raise ValueError + self.jwd = str(self.jwd) datetime.fromisoformat(self.date) # will raise ValueError if invalid From e0f2f4de8afa8bdb7c663452611950fa22ee739c Mon Sep 17 00:00:00 2001 From: Cameron Hyde Date: Thu, 14 Nov 2024 06:50:52 +1000 Subject: [PATCH 14/14] Soft pin requirements.txt --- files/requirements.txt | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/files/requirements.txt b/files/requirements.txt index f17ae1d..b9e86f8 100644 --- a/files/requirements.txt +++ b/files/requirements.txt @@ -1,3 +1,3 @@ -psycopg2-binary -requests -pyyaml +psycopg2-binary>=2.9,<3.0 +PyYAML>=6.0,<7.0 +requests>=2.32,<3.0