From 5a0c5af07214ab84cf019bd90ce95824750a9d05 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Cristina=20Mu=C3=B1oz?= Date: Thu, 6 Feb 2020 09:10:55 -0800 Subject: [PATCH 1/2] Misc cleanup and TODOs on malware checks. - Change backfill function to invoke `IMalwareCheckService` interface - Add support for `**kwargs` to `IMalwareCheckService` interface - Rename variable from reserved word `file` to `release_file` - Add `FatalCheckException` for non-retryable exceptions - Replace `MALWARE_CHECK_BACKEND` in dev/environment --- dev/environment | 3 +- tests/common/checks/hooked.py | 3 +- .../checks/setup_patterns/test_check.py | 5 +- tests/unit/malware/test_services.py | 8 ++- tests/unit/malware/test_tasks.py | 55 ++++++++++++++++--- .../malware/checks/setup_patterns/check.py | 24 ++++---- warehouse/malware/checks/utils.py | 4 ++ warehouse/malware/interfaces.py | 2 +- warehouse/malware/services.py | 8 +-- warehouse/malware/tasks.py | 12 +++- 10 files changed, 90 insertions(+), 34 deletions(-) diff --git a/dev/environment b/dev/environment index ec7eeae6d2f3..5d9fe6cc5af6 100644 --- a/dev/environment +++ b/dev/environment @@ -29,8 +29,7 @@ MAIL_BACKEND=warehouse.email.services.SMTPEmailSender host=smtp port=2525 ssl=fa BREACHED_PASSWORDS=warehouse.accounts.NullPasswordBreachedService -#TODO: change this to PrinterMalwareCheckService before deploy -MALWARE_CHECK_BACKEND=warehouse.malware.services.DatabaseMalwareCheckService +MALWARE_CHECK_BACKEND=warehouse.malware.services.PrinterMalwareCheckService METRICS_BACKEND=warehouse.metrics.DataDogMetrics host=notdatadog diff --git a/tests/common/checks/hooked.py b/tests/common/checks/hooked.py index 8a3e16a3cbf9..549ec3a23992 100644 --- a/tests/common/checks/hooked.py +++ b/tests/common/checks/hooked.py @@ -11,6 +11,7 @@ # limitations under the License. 
from warehouse.malware.checks.base import MalwareCheckBase +from warehouse.malware.checks.utils import FatalCheckException from warehouse.malware.models import VerdictClassification, VerdictConfidence @@ -29,7 +30,7 @@ def __init__(self, db): def scan(self, **kwargs): file_id = kwargs.get("obj_id") if file_id is None: - return + raise FatalCheckException("Missing required kwarg `obj_id`") self.add_verdict( file_id=file_id, diff --git a/tests/unit/malware/checks/setup_patterns/test_check.py b/tests/unit/malware/checks/setup_patterns/test_check.py index 0dbd5c19d06f..c25556cf988d 100644 --- a/tests/unit/malware/checks/setup_patterns/test_check.py +++ b/tests/unit/malware/checks/setup_patterns/test_check.py @@ -43,9 +43,8 @@ def test_scan_missing_kwargs(db_session, obj, file_url): name="SetupPatternCheck", state=MalwareCheckState.Enabled ) check = c.SetupPatternCheck(db_session) - check.scan(obj=obj, file_url=file_url) - - assert check._verdicts == [] + with pytest.raises(c.FatalCheckException): + check.scan(obj=obj, file_url=file_url) def test_scan_non_sdist(db_session): diff --git a/tests/unit/malware/test_services.py b/tests/unit/malware/test_services.py index 7a9cb636f720..1b950814c585 100644 --- a/tests/unit/malware/test_services.py +++ b/tests/unit/malware/test_services.py @@ -11,6 +11,7 @@ # limitations under the License. 
import pretend +import pytest from zope.interface.verify import verifyClass @@ -31,13 +32,14 @@ def test_create_service(self): service = PrinterMalwareCheckService.create_service(None, request) assert service.executor == print - def test_run_checks(self, capfd): + @pytest.mark.parametrize(("kwargs"), [{}, {"manually_triggered": True}]) + def test_run_checks(self, capfd, kwargs): request = pretend.stub() service = PrinterMalwareCheckService.create_service(None, request) checks = ["one", "two", "three"] - service.run_checks(checks) + service.run_checks(checks, **kwargs) out, err = capfd.readouterr() - assert out == "one\ntwo\nthree\n" + assert out == "".join(["%s %s\n" % (check, kwargs) for check in checks]) class TestDatabaseMalwareService: diff --git a/tests/unit/malware/test_tasks.py b/tests/unit/malware/test_tasks.py index 6a333c31b018..6ddd59dc4e71 100644 --- a/tests/unit/malware/test_tasks.py +++ b/tests/unit/malware/test_tasks.py @@ -16,6 +16,7 @@ from warehouse.malware import tasks from warehouse.malware.models import MalwareCheck, MalwareCheckState, MalwareVerdict +from warehouse.malware.services import PrinterMalwareCheckService from ...common import checks as test_checks from ...common.db.malware import MalwareCheckFactory, MalwareVerdictFactory @@ -101,6 +102,28 @@ def test_missing_check(self, db_request, monkeypatch): task, db_request, "DoesNotExistCheck", ) + def test_missing_obj_id(self, db_session, monkeypatch): + monkeypatch.setattr(tasks, "checks", test_checks) + task = pretend.stub() + + MalwareCheckFactory.create( + name="ExampleHookedCheck", state=MalwareCheckState.Enabled + ) + task = pretend.stub() + + request = pretend.stub( + db=db_session, + log=pretend.stub(error=pretend.call_recorder(lambda *args, **kwargs: None)), + ) + + tasks.run_check(task, request, "ExampleHookedCheck") + + assert request.log.error.calls == [ + pretend.call( + "Fatal exception: ExampleHookedCheck: Missing required kwarg `obj_id`" + ) + ] + def test_retry(self, 
db_session, monkeypatch): monkeypatch.setattr(tasks, "checks", test_checks) exc = Exception("Scan failed") @@ -145,33 +168,47 @@ def test_invalid_check_name(self, db_request, monkeypatch): @pytest.mark.parametrize( ("num_objects", "num_runs"), [(11, 1), (11, 11), (101, 90)] ) - def test_run(self, db_session, num_objects, num_runs, monkeypatch): + def test_run(self, db_session, capfd, num_objects, num_runs, monkeypatch): monkeypatch.setattr(tasks, "checks", test_checks) + + ids = [] for i in range(num_objects): - FileFactory.create() + ids.append(FileFactory.create().id) MalwareCheckFactory.create( name="ExampleHookedCheck", state=MalwareCheckState.Enabled ) - enqueue_recorder = pretend.stub( - delay=pretend.call_recorder(lambda *a, **kw: None) - ) - task = pretend.call_recorder(lambda *args, **kwargs: enqueue_recorder) - request = pretend.stub( db=db_session, log=pretend.stub(info=pretend.call_recorder(lambda *args, **kwargs: None)), - task=task, + find_service_factory=pretend.call_recorder( + lambda interface: PrinterMalwareCheckService.create_service + ), ) + task = pretend.stub() + tasks.backfill(task, request, "ExampleHookedCheck", num_runs) assert request.log.info.calls == [ pretend.call("Running backfill on %d Files." 
% num_runs) ] - assert len(enqueue_recorder.delay.calls) == num_runs + assert request.find_service_factory.calls == [ + pretend.call(tasks.IMalwareCheckService) + ] + + out, err = capfd.readouterr() + num_output_lines = 0 + for file_id in ids: + logged_output = "ExampleHookedCheck:%s %s\n" % ( + file_id, + {"manually_triggered": True}, + ) + num_output_lines += 1 if logged_output in out else 0 + + assert num_output_lines == num_runs class TestSyncChecks: diff --git a/warehouse/malware/checks/setup_patterns/check.py b/warehouse/malware/checks/setup_patterns/check.py index c2127bb292db..9cb74061d0e1 100644 --- a/warehouse/malware/checks/setup_patterns/check.py +++ b/warehouse/malware/checks/setup_patterns/check.py @@ -17,7 +17,11 @@ import yara from warehouse.malware.checks.base import MalwareCheckBase -from warehouse.malware.checks.utils import extract_file_content, fetch_url_content +from warehouse.malware.checks.utils import ( + FatalCheckException, + extract_file_content, + fetch_url_content, +) from warehouse.malware.models import VerdictClassification, VerdictConfidence @@ -45,14 +49,14 @@ def _load_yara_rules(self): return yara.compile(filepath=self._yara_rule_file) def scan(self, **kwargs): - file = kwargs.get("obj") + release_file = kwargs.get("obj") file_url = kwargs.get("file_url") - if file is None or file_url is None: - # TODO: Maybe raise here, since the absence of these - # arguments is a use/user error. - return + if release_file is None or file_url is None: + raise FatalCheckException( + "Release file or file url is None, indicating user error." + ) - if file.packagetype != "sdist": + if release_file.packagetype != "sdist": # Per PEP 491: bdists do not contain setup.py. # This check only scans dists that contain setup.py, so # we have nothing to perform. 
@@ -62,7 +66,7 @@ def scan(self, **kwargs): setup_py_contents = extract_file_content(archive_stream, "setup.py") if setup_py_contents is None: self.add_verdict( - file_id=file.id, + file_id=release_file.id, classification=VerdictClassification.Indeterminate, confidence=VerdictConfidence.High, message="sdist does not contain a suitable setup.py for analysis", @@ -92,7 +96,7 @@ def scan(self, **kwargs): } self.add_verdict( - file_id=file.id, + file_id=release_file.id, classification=classification, confidence=confidence, message=message, @@ -101,7 +105,7 @@ def scan(self, **kwargs): else: # No matches? Report a low-confidence benign verdict. self.add_verdict( - file_id=file.id, + file_id=release_file.id, classification=VerdictClassification.Benign, confidence=VerdictConfidence.Low, message="No malicious patterns found in setup.py", diff --git a/warehouse/malware/checks/utils.py b/warehouse/malware/checks/utils.py index 5ddda01ccc7c..d3b4acc7d908 100644 --- a/warehouse/malware/checks/utils.py +++ b/warehouse/malware/checks/utils.py @@ -78,3 +78,7 @@ def extract_file_content(archive_stream, file_path): return None except tarfile.TarError: return None + + +class FatalCheckException(Exception): + pass diff --git a/warehouse/malware/interfaces.py b/warehouse/malware/interfaces.py index f179aa374d55..482907735f33 100644 --- a/warehouse/malware/interfaces.py +++ b/warehouse/malware/interfaces.py @@ -20,7 +20,7 @@ def create_service(context, request): created for. 
""" - def run_checks(checks): + def run_checks(checks, **kwargs): """ Run a given set of Checks """ diff --git a/warehouse/malware/services.py b/warehouse/malware/services.py index f2f454b964e2..ccb36723b345 100644 --- a/warehouse/malware/services.py +++ b/warehouse/malware/services.py @@ -25,9 +25,9 @@ def __init__(self, executor): def create_service(cls, context, request): return cls(print) - def run_checks(self, checks): + def run_checks(self, checks, **kwargs): for check in checks: - self.executor(check) + self.executor(check, kwargs) @implementer(IMalwareCheckService) @@ -39,7 +39,7 @@ def __init__(self, executor): def create_service(cls, context, request): return cls(request.task(run_check).delay) - def run_checks(self, checks): + def run_checks(self, checks, **kwargs): for check_info in checks: check_name, obj_id = check_info.split(":") - self.executor(check_name, obj_id) + self.executor(check_name, obj_id, **kwargs) diff --git a/warehouse/malware/tasks.py b/warehouse/malware/tasks.py index 1fbe46b7a638..7355df973f06 100644 --- a/warehouse/malware/tasks.py +++ b/warehouse/malware/tasks.py @@ -17,6 +17,8 @@ import warehouse.malware.checks as checks import warehouse.packaging.models as packaging_models +from warehouse.malware.checks.utils import FatalCheckException +from warehouse.malware.interfaces import IMalwareCheckService from warehouse.malware.models import MalwareCheck, MalwareCheckState, MalwareVerdict from warehouse.malware.utils import get_check_fields from warehouse.tasks import task @@ -47,6 +49,9 @@ def run_check(task, request, check_name, obj_id=None, manually_triggered=False): try: check.run(**kwargs) + except FatalCheckException as exc: + request.log.error("Fatal exception: %s: %s" % (check_name, str(exc))) + return except Exception as exc: request.log.error("Error executing check %s: %s" % (check_name, str(exc))) raise task.retry(exc=exc) @@ -63,8 +68,13 @@ def backfill(task, request, check_name, num_objects): request.log.info("Running 
backfill on %d %ss." % (num_objects, check.hooked_object)) + runs = set() for (elem_id,) in query: - request.task(run_check).delay(check_name, elem_id, manually_triggered=True) + runs.update([f"{check_name}:{elem_id}"]) + + malware_check_service = request.find_service_factory(IMalwareCheckService) + malware_check = malware_check_service(None, request) + malware_check.run_checks(runs, manually_triggered=True) @task(bind=True, ignore_result=True, acks_late=True) From b0707784cabaf89c69506f2599e998abd868cf18 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Cristina=20Mu=C3=B1oz?= Date: Fri, 7 Feb 2020 13:05:00 -0800 Subject: [PATCH 2/2] Make `IMalwareCheckService` the entrypoint for `run_check` - Add `run_scheduled_check` task that invokes this interface. - Remove useless utility method - Move `FatalCheckException` into warehouse/malware/errors.py. --- tests/common/checks/hooked.py | 2 +- tests/unit/admin/views/test_checks.py | 4 +-- tests/unit/malware/test_init.py | 17 +++++----- tests/unit/malware/test_services.py | 32 ++++++++++++++++--- tests/unit/malware/test_tasks.py | 31 ++++++++++++++++++ warehouse/admin/views/checks.py | 4 +-- warehouse/malware/__init__.py | 11 ++++--- .../malware/checks/setup_patterns/check.py | 7 ++-- warehouse/malware/checks/utils.py | 4 --- warehouse/malware/errors.py | 15 +++++++++ warehouse/malware/services.py | 10 ++++-- warehouse/malware/tasks.py | 9 +++++- warehouse/malware/utils.py | 14 +------- 13 files changed, 113 insertions(+), 47 deletions(-) create mode 100644 warehouse/malware/errors.py diff --git a/tests/common/checks/hooked.py b/tests/common/checks/hooked.py index 549ec3a23992..2aa72a1bb8ae 100644 --- a/tests/common/checks/hooked.py +++ b/tests/common/checks/hooked.py @@ -11,7 +11,7 @@ # limitations under the License. 
from warehouse.malware.checks.base import MalwareCheckBase -from warehouse.malware.checks.utils import FatalCheckException +from warehouse.malware.errors import FatalCheckException from warehouse.malware.models import VerdictClassification, VerdictConfidence diff --git a/tests/unit/admin/views/test_checks.py b/tests/unit/admin/views/test_checks.py index 3edf31232e91..8aafc0a68319 100644 --- a/tests/unit/admin/views/test_checks.py +++ b/tests/unit/admin/views/test_checks.py @@ -17,7 +17,7 @@ from warehouse.admin.views import checks as views from warehouse.malware.models import MalwareCheckState, MalwareCheckType -from warehouse.malware.tasks import backfill, run_check +from warehouse.malware.tasks import backfill, run_scheduled_check from ....common.db.malware import MalwareCheckFactory @@ -208,7 +208,7 @@ def test_success(self, db_request, check_type): assert db_request.session.flash.calls == [ pretend.call("Running %s now!" % check.name, queue="success",) ] - assert db_request.task.calls == [pretend.call(run_check)] + assert db_request.task.calls == [pretend.call(run_scheduled_check)] assert backfill_recorder.delay.calls == [ pretend.call(check.name, manually_triggered=True) ] diff --git a/tests/unit/malware/test_init.py b/tests/unit/malware/test_init.py index 32628118398d..fc642a23cd41 100644 --- a/tests/unit/malware/test_init.py +++ b/tests/unit/malware/test_init.py @@ -17,9 +17,8 @@ from celery.schedules import crontab from warehouse import malware -from warehouse.malware import utils from warehouse.malware.interfaces import IMalwareCheckService -from warehouse.malware.tasks import run_check +from warehouse.malware.tasks import run_scheduled_check from ...common import checks as test_checks from ...common.db.accounts import UserFactory @@ -30,7 +29,7 @@ def test_determine_malware_checks_no_checks(monkeypatch, db_request): def get_enabled_hooked_checks(session): return defaultdict(list) - monkeypatch.setattr(utils, "get_enabled_hooked_checks", 
get_enabled_hooked_checks) + monkeypatch.setattr(malware, "get_enabled_hooked_checks", get_enabled_hooked_checks) project = ProjectFactory.create(name="foo") release = ReleaseFactory.create(project=project) @@ -49,7 +48,7 @@ def get_enabled_hooked_checks(session): result["Release"] = ["Check3"] return result - monkeypatch.setattr(utils, "get_enabled_hooked_checks", get_enabled_hooked_checks) + monkeypatch.setattr(malware, "get_enabled_hooked_checks", get_enabled_hooked_checks) project = ProjectFactory.create(name="foo") release = ReleaseFactory.create(project=project) @@ -68,7 +67,7 @@ def get_enabled_hooked_checks(session): result["Release"] = ["Check3"] return result - monkeypatch.setattr(utils, "get_enabled_hooked_checks", get_enabled_hooked_checks) + monkeypatch.setattr(malware, "get_enabled_hooked_checks", get_enabled_hooked_checks) user = UserFactory.create() @@ -85,7 +84,7 @@ def get_enabled_hooked_checks(session): result["Release"] = ["Check3"] return result - monkeypatch.setattr(utils, "get_enabled_hooked_checks", get_enabled_hooked_checks) + monkeypatch.setattr(malware, "get_enabled_hooked_checks", get_enabled_hooked_checks) project = ProjectFactory.create(name="foo") release = ReleaseFactory.create(project=project) @@ -105,7 +104,7 @@ def get_enabled_hooked_checks(session): result["Release"] = ["Check3"] return result - monkeypatch.setattr(utils, "get_enabled_hooked_checks", get_enabled_hooked_checks) + monkeypatch.setattr(malware, "get_enabled_hooked_checks", get_enabled_hooked_checks) project = ProjectFactory.create(name="foo") release = ReleaseFactory.create(project=project) @@ -179,6 +178,8 @@ def test_includeme(monkeypatch): assert config.add_periodic_task.calls == [ pretend.call( - crontab(minute="0", hour="*/8"), run_check, args=("ExampleScheduledCheck",) + crontab(minute="0", hour="*/8"), + run_scheduled_check, + args=("ExampleScheduledCheck",), ) ] diff --git a/tests/unit/malware/test_services.py b/tests/unit/malware/test_services.py index 
1b950814c585..3348bb894d81 100644 --- a/tests/unit/malware/test_services.py +++ b/tests/unit/malware/test_services.py @@ -52,12 +52,36 @@ def test_create_service(self, db_request): service = DatabaseMalwareCheckService.create_service(None, db_request) assert service.executor == db_request.task(run_check).delay - def test_run_checks(self, db_request): - _delay = pretend.call_recorder(lambda *args: None) + def test_run_hooked_check(self, db_request): + _delay = pretend.call_recorder(lambda *args, **kwargs: None) db_request.task = lambda x: pretend.stub(delay=_delay) service = DatabaseMalwareCheckService.create_service(None, db_request) - checks = ["MyTestCheck:ba70267f-fabf-496f-9ac2-d237a983b187"] + checks = [ + "MyTestCheck:ba70267f-fabf-496f-9ac2-d237a983b187", + "AnotherCheck:44f57b0e-c5b0-47c5-8713-341cf392efe2", + "FinalCheck:e8518a15-8f01-430e-8f5b-87644007c9c0", + ] service.run_checks(checks) assert _delay.calls == [ - pretend.call("MyTestCheck", "ba70267f-fabf-496f-9ac2-d237a983b187") + pretend.call("MyTestCheck", obj_id="ba70267f-fabf-496f-9ac2-d237a983b187"), + pretend.call("AnotherCheck", obj_id="44f57b0e-c5b0-47c5-8713-341cf392efe2"), + pretend.call("FinalCheck", obj_id="e8518a15-8f01-430e-8f5b-87644007c9c0"), + ] + + def test_run_scheduled_check(self, db_request): + _delay = pretend.call_recorder(lambda *args, **kwargs: None) + db_request.task = lambda x: pretend.stub(delay=_delay) + service = DatabaseMalwareCheckService.create_service(None, db_request) + checks = ["MyTestScheduledCheck"] + service.run_checks(checks) + assert _delay.calls == [pretend.call("MyTestScheduledCheck")] + + def test_run_triggered_check(self, db_request): + _delay = pretend.call_recorder(lambda *args, **kwargs: None) + db_request.task = lambda x: pretend.stub(delay=_delay) + service = DatabaseMalwareCheckService.create_service(None, db_request) + checks = ["MyTriggeredCheck"] + service.run_checks(checks, manually_triggered=True) + assert _delay.calls == [ + 
pretend.call("MyTriggeredCheck", manually_triggered=True) ] diff --git a/tests/unit/malware/test_tasks.py b/tests/unit/malware/test_tasks.py index 6ddd59dc4e71..ec8f5e8dac48 100644 --- a/tests/unit/malware/test_tasks.py +++ b/tests/unit/malware/test_tasks.py @@ -158,6 +158,37 @@ def scan(self, **kwargs): assert task.retry.calls == [pretend.call(exc=exc)] +class TestRunScheduledCheck: + def test_invalid_check_name(self, db_request, monkeypatch): + monkeypatch.setattr(tasks, "checks", test_checks) + task = pretend.stub() + with pytest.raises(AttributeError): + tasks.run_scheduled_check(task, db_request, "DoesNotExist") + + def test_run_check(self, db_session, capfd, monkeypatch): + MalwareCheckFactory.create( + name="ExampleScheduledCheck", state=MalwareCheckState.Enabled + ) + + request = pretend.stub( + db=db_session, + find_service_factory=pretend.call_recorder( + lambda interface: PrinterMalwareCheckService.create_service + ), + ) + + task = pretend.stub() + + tasks.run_scheduled_check(task, request, "ExampleScheduledCheck") + + assert request.find_service_factory.calls == [ + pretend.call(tasks.IMalwareCheckService) + ] + + out, err = capfd.readouterr() + assert out == "ExampleScheduledCheck {'manually_triggered': False}\n" + + class TestBackfill: def test_invalid_check_name(self, db_request, monkeypatch): monkeypatch.setattr(tasks, "checks", test_checks) diff --git a/warehouse/admin/views/checks.py b/warehouse/admin/views/checks.py index 6b203d88a20a..817c1fd66b7a 100644 --- a/warehouse/admin/views/checks.py +++ b/warehouse/admin/views/checks.py @@ -15,7 +15,7 @@ from sqlalchemy.orm.exc import NoResultFound from warehouse.malware.models import MalwareCheck, MalwareCheckState, MalwareCheckType -from warehouse.malware.tasks import backfill, remove_verdicts, run_check +from warehouse.malware.tasks import backfill, remove_verdicts, run_scheduled_check EVALUATION_RUN_SIZE = 10000 @@ -94,7 +94,7 @@ def run_evaluation(request): else: request.session.flash(f"Running 
{check.name} now!", queue="success") - request.task(run_check).delay(check.name, manually_triggered=True) + request.task(run_scheduled_check).delay(check.name, manually_triggered=True) return HTTPSeeOther( request.route_path("admin.checks.detail", check_name=check.name) diff --git a/warehouse/malware/__init__.py b/warehouse/malware/__init__.py index a1b76b82f2b5..2336c9f11dce 100644 --- a/warehouse/malware/__init__.py +++ b/warehouse/malware/__init__.py @@ -17,9 +17,10 @@ import warehouse.malware.checks as checks from warehouse import db -from warehouse.malware import utils from warehouse.malware.interfaces import IMalwareCheckService -from warehouse.malware.tasks import run_check +from warehouse.malware.models import MalwareCheckObjectType +from warehouse.malware.tasks import run_scheduled_check +from warehouse.malware.utils import get_enabled_hooked_checks @db.listens_for(db.Session, "after_flush") @@ -31,13 +32,13 @@ def determine_malware_checks(config, session, flush_context): [ obj.__class__.__name__ for obj in session.new - if obj.__class__.__name__ in utils.valid_check_types() + if obj.__class__.__name__ in MalwareCheckObjectType.__members__ ] ): return malware_checks = session.info.setdefault("warehouse.malware.checks", set()) - enabled_checks = utils.get_enabled_hooked_checks(session) + enabled_checks = get_enabled_hooked_checks(session) for obj in session.new: for check_name in enabled_checks.get(obj.__class__.__name__, []): malware_checks.update([f"{check_name}:{obj.id}"]) @@ -71,5 +72,5 @@ def includeme(config): check = check_obj[1] if check.check_type == "scheduled": config.add_periodic_task( - crontab(**check.schedule), run_check, args=(check_obj[0],) + crontab(**check.schedule), run_scheduled_check, args=(check_obj[0],) ) diff --git a/warehouse/malware/checks/setup_patterns/check.py b/warehouse/malware/checks/setup_patterns/check.py index 9cb74061d0e1..2a92a36ed9a7 100644 --- a/warehouse/malware/checks/setup_patterns/check.py +++ 
b/warehouse/malware/checks/setup_patterns/check.py @@ -17,11 +17,8 @@ import yara from warehouse.malware.checks.base import MalwareCheckBase -from warehouse.malware.checks.utils import ( - FatalCheckException, - extract_file_content, - fetch_url_content, -) +from warehouse.malware.checks.utils import extract_file_content, fetch_url_content +from warehouse.malware.errors import FatalCheckException from warehouse.malware.models import VerdictClassification, VerdictConfidence diff --git a/warehouse/malware/checks/utils.py b/warehouse/malware/checks/utils.py index d3b4acc7d908..5ddda01ccc7c 100644 --- a/warehouse/malware/checks/utils.py +++ b/warehouse/malware/checks/utils.py @@ -78,7 +78,3 @@ def extract_file_content(archive_stream, file_path): return None except tarfile.TarError: return None - - -class FatalCheckException(Exception): - pass diff --git a/warehouse/malware/errors.py b/warehouse/malware/errors.py new file mode 100644 index 000000000000..837c079cef18 --- /dev/null +++ b/warehouse/malware/errors.py @@ -0,0 +1,15 @@ +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ + +class FatalCheckException(Exception): + pass diff --git a/warehouse/malware/services.py b/warehouse/malware/services.py index ccb36723b345..5566250bbad8 100644 --- a/warehouse/malware/services.py +++ b/warehouse/malware/services.py @@ -41,5 +41,11 @@ def create_service(cls, context, request): def run_checks(self, checks, **kwargs): for check_info in checks: - check_name, obj_id = check_info.split(":") - self.executor(check_name, obj_id, **kwargs) + # Hooked checks + if ":" in check_info: + check_name, obj_id = check_info.split(":") + kwargs["obj_id"] = obj_id + # Scheduled checks + else: + check_name = check_info + self.executor(check_name, **kwargs) diff --git a/warehouse/malware/tasks.py b/warehouse/malware/tasks.py index 7355df973f06..ac27f1577b0b 100644 --- a/warehouse/malware/tasks.py +++ b/warehouse/malware/tasks.py @@ -17,7 +17,7 @@ import warehouse.malware.checks as checks import warehouse.packaging.models as packaging_models -from warehouse.malware.checks.utils import FatalCheckException +from warehouse.malware.errors import FatalCheckException from warehouse.malware.interfaces import IMalwareCheckService from warehouse.malware.models import MalwareCheck, MalwareCheckState, MalwareVerdict from warehouse.malware.utils import get_check_fields @@ -57,6 +57,13 @@ def run_check(task, request, check_name, obj_id=None, manually_triggered=False): raise task.retry(exc=exc) +@task(bind=True, ignore_result=True, acks_late=True) +def run_scheduled_check(task, request, check_name, manually_triggered=False): + malware_check_service = request.find_service_factory(IMalwareCheckService) + malware_check = malware_check_service(None, request) + malware_check.run_checks([check_name], manually_triggered=manually_triggered) + + @task(bind=True, ignore_result=True, acks_late=True) def backfill(task, request, check_name, num_objects): """ diff --git a/warehouse/malware/utils.py b/warehouse/malware/utils.py index 879d9e8a3a7b..95ebf4b796b5 100644 --- 
a/warehouse/malware/utils.py +++ b/warehouse/malware/utils.py @@ -10,21 +10,9 @@ # See the License for the specific language governing permissions and # limitations under the License. -import functools - from collections import defaultdict -from warehouse.malware.models import ( - MalwareCheck, - MalwareCheckObjectType, - MalwareCheckState, - MalwareCheckType, -) - - -@functools.lru_cache() -def valid_check_types(): - return set([t.value for t in MalwareCheckObjectType]) +from warehouse.malware.models import MalwareCheck, MalwareCheckState, MalwareCheckType def get_check_fields(check):