diff --git a/tests/common/checks/scheduled.py b/tests/common/checks/scheduled.py index 128ce102a83b..d5c80962e45c 100644 --- a/tests/common/checks/scheduled.py +++ b/tests/common/checks/scheduled.py @@ -27,11 +27,11 @@ class ExampleScheduledCheck(MalwareCheckBase): def __init__(self, db): super().__init__(db) - def scan(self): + def scan(self, **kwargs): project = self.db.query(Project).first() self.add_verdict( project_id=project.id, - classification=VerdictClassification.benign, + classification=VerdictClassification.Benign, confidence=VerdictConfidence.High, message="Nothing to see here!", ) diff --git a/tests/unit/admin/test_routes.py b/tests/unit/admin/test_routes.py index 28538ad12dac..451b6ad5a9e0 100644 --- a/tests/unit/admin/test_routes.py +++ b/tests/unit/admin/test_routes.py @@ -133,8 +133,8 @@ def test_includeme(): domain=warehouse, ), pretend.call( - "admin.checks.run_backfill", - "/admin/checks/{check_name}/run_backfill", + "admin.checks.run_evaluation", + "/admin/checks/{check_name}/run_evaluation", domain=warehouse, ), pretend.call("admin.verdicts.list", "/admin/verdicts/", domain=warehouse), diff --git a/tests/unit/admin/views/test_checks.py b/tests/unit/admin/views/test_checks.py index c8fa6512aeaa..0c3baeb03369 100644 --- a/tests/unit/admin/views/test_checks.py +++ b/tests/unit/admin/views/test_checks.py @@ -16,7 +16,8 @@ from pyramid.httpexceptions import HTTPNotFound from warehouse.admin.views import checks as views -from warehouse.malware.models import MalwareCheckState +from warehouse.malware.models import MalwareCheckState, MalwareCheckType +from warehouse.malware.tasks import backfill, run_check from ....common.db.malware import MalwareCheckFactory @@ -46,6 +47,7 @@ def test_get_check(self, db_request): "check": check, "checks": [check], "states": MalwareCheckState, + "evaluation_run_size": 10000, } def test_get_check_many_versions(self, db_request): @@ -56,6 +58,7 @@ def test_get_check_many_versions(self, db_request): "check": check2, "checks": [check2, check1], "states": MalwareCheckState, + "evaluation_run_size": 10000, } def test_get_check_not_found(self, db_request): @@ -129,17 +132,17 @@ def test_change_to_invalid_state(self, db_request): assert check.state == initial_state -class TestRunBackfill: +class TestRunEvaluation: @pytest.mark.parametrize( ("check_state", "message"), [ ( MalwareCheckState.Disabled, - "Check must be in 'enabled' or 'evaluation' state to run a backfill.", + "Check must be in 'enabled' or 'evaluation' state to manually execute.", ), ( MalwareCheckState.WipedOut, - "Check must be in 'enabled' or 'evaluation' state to run a backfill.", + "Check must be in 'enabled' or 'evaluation' state to manually execute.", ), ], ) @@ -152,15 +155,21 @@ def test_invalid_backfill_parameters(self, db_request, check_state, message): ) db_request.route_path = pretend.call_recorder( - lambda *a, **kw: "/admin/checks/%s/run_backfill" % check.name + lambda *a, **kw: "/admin/checks/%s/run_evaluation" % check.name ) - views.run_backfill(db_request) + views.run_evaluation(db_request) assert db_request.session.flash.calls == [pretend.call(message, queue="error")] - def test_sucess(self, db_request): - check = MalwareCheckFactory.create(state=MalwareCheckState.Enabled) + @pytest.mark.parametrize( + ("check_type"), [MalwareCheckType.EventHook, MalwareCheckType.Scheduled] + ) + def test_success(self, db_request, check_type): + + check = MalwareCheckFactory.create( + check_type=check_type, state=MalwareCheckState.Enabled + ) db_request.matchdict["check_name"] = check.name db_request.session = pretend.stub( @@ -168,7 +177,7 @@ def test_sucess(self, db_request): ) db_request.route_path = pretend.call_recorder( - lambda *a, **kw: "/admin/checks/%s/run_backfill" % check.name + lambda *a, **kw: "/admin/checks/%s/run_evaluation" % check.name ) backfill_recorder = pretend.stub( @@ -177,13 +186,25 @@ def test_sucess(self, db_request): db_request.task = pretend.call_recorder(lambda *a, **kw: backfill_recorder) - views.run_backfill(db_request) - - assert db_request.session.flash.calls == [ - pretend.call( - "Running %s on 10000 %ss!" % (check.name, check.hooked_object.value), - queue="success", - ) - ] - - assert backfill_recorder.delay.calls == [pretend.call(check.name, 10000)] + views.run_evaluation(db_request) + + if check_type == MalwareCheckType.EventHook: + assert db_request.session.flash.calls == [ + pretend.call( + "Running %s on 10000 %ss!" + % (check.name, check.hooked_object.value), + queue="success", + ) + ] + assert db_request.task.calls == [pretend.call(backfill)] + assert backfill_recorder.delay.calls == [pretend.call(check.name, 10000)] + elif check_type == MalwareCheckType.Scheduled: + assert db_request.session.flash.calls == [ + pretend.call("Running %s now!" % check.name, queue="success",) + ] + assert db_request.task.calls == [pretend.call(run_check)] + assert backfill_recorder.delay.calls == [ + pretend.call(check.name, manually_triggered=True) + ] + else: + raise Exception("Invalid check type: %s" % check_type) diff --git a/tests/unit/malware/test_init.py b/tests/unit/malware/test_init.py index c8e1f3d7acf4..32628118398d 100644 --- a/tests/unit/malware/test_init.py +++ b/tests/unit/malware/test_init.py @@ -14,9 +14,12 @@ import pretend +from celery.schedules import crontab + from warehouse import malware from warehouse.malware import utils from warehouse.malware.interfaces import IMalwareCheckService +from warehouse.malware.tasks import run_check from ...common import checks as test_checks from ...common.db.accounts import UserFactory @@ -165,6 +168,7 @@ def test_includeme(monkeypatch): registry=pretend.stub( settings={"malware_check.backend": "TestMalwareCheckService"} ), + add_periodic_task=pretend.call_recorder(lambda *a, **kw: None), ) malware.includeme(config) @@ -172,3 +176,9 @@ def test_includeme(monkeypatch): assert config.register_service_factory.calls == [ pretend.call(malware_check_class.create_service, IMalwareCheckService) ] + + assert config.add_periodic_task.calls == [ + pretend.call( + crontab(minute="0", hour="*/8"), run_check, args=("ExampleScheduledCheck",) + ) + ] diff --git a/tests/unit/malware/test_tasks.py b/tests/unit/malware/test_tasks.py index de278f4ad825..6a333c31b018 100644 --- a/tests/unit/malware/test_tasks.py +++ b/tests/unit/malware/test_tasks.py @@ -14,8 +14,6 @@ import pretend import pytest -from sqlalchemy.orm.exc import NoResultFound - from warehouse.malware import tasks from warehouse.malware.models import MalwareCheck, MalwareCheckState, MalwareVerdict @@ -34,45 +32,86 @@ def test_success(self, db_request, monkeypatch): name="ExampleHookedCheck", state=MalwareCheckState.Enabled ) task = pretend.stub() - tasks.run_check(task, db_request, "ExampleHookedCheck", file0.id) + tasks.run_check(task, db_request, "ExampleHookedCheck", obj_id=file0.id) assert db_request.route_url.calls == [ pretend.call("packaging.file", path=file0.path) ] assert db_request.db.query(MalwareVerdict).one() - def test_disabled_check(self, db_request, monkeypatch): + @pytest.mark.parametrize(("manually_triggered"), [True, False]) + def test_evaluation_run(self, db_session, monkeypatch, manually_triggered): + monkeypatch.setattr(tasks, "checks", test_checks) + MalwareCheckFactory.create( + name="ExampleScheduledCheck", state=MalwareCheckState.Evaluation + ) + ProjectFactory.create() + task = pretend.stub() + + request = pretend.stub( + db=db_session, + log=pretend.stub(info=pretend.call_recorder(lambda *args, **kwargs: None)), + ) + + tasks.run_check( + task, + request, + "ExampleScheduledCheck", + manually_triggered=manually_triggered, + ) + + if manually_triggered: + assert db_session.query(MalwareVerdict).one() + else: + assert request.log.info.calls == [ + pretend.call( + "ExampleScheduledCheck is in the `evaluation` state and must be \ +manually triggered to run." + ) + ] + assert db_session.query(MalwareVerdict).all() == [] + + def test_disabled_check(self, db_session, monkeypatch): monkeypatch.setattr(tasks, "checks", test_checks) MalwareCheckFactory.create( name="ExampleHookedCheck", state=MalwareCheckState.Disabled ) task = pretend.stub() + request = pretend.stub( + db=db_session, + log=pretend.stub(info=pretend.call_recorder(lambda *args, **kwargs: None)), + ) file = FileFactory.create() - with pytest.raises(NoResultFound): - tasks.run_check(task, db_request, "ExampleHookedCheck", file.id) + tasks.run_check( + task, request, "ExampleHookedCheck", obj_id=file.id, + ) + + assert request.log.info.calls == [ + pretend.call("Check ExampleHookedCheck isn't active. Aborting.") + ] def test_missing_check(self, db_request, monkeypatch): monkeypatch.setattr(tasks, "checks", test_checks) task = pretend.stub() - file = FileFactory.create() - with pytest.raises(AttributeError): - tasks.run_check(task, db_request, "DoesNotExistCheck", file.id) + tasks.run_check( + task, db_request, "DoesNotExistCheck", + ) def test_retry(self, db_session, monkeypatch): + monkeypatch.setattr(tasks, "checks", test_checks) exc = Exception("Scan failed") def scan(self, **kwargs): raise exc - monkeypatch.setattr(tasks, "checks", test_checks) monkeypatch.setattr(tasks.checks.ExampleHookedCheck, "scan", scan) MalwareCheckFactory.create( - name="ExampleHookedCheck", state=MalwareCheckState.Evaluation + name="ExampleHookedCheck", state=MalwareCheckState.Enabled ) task = pretend.stub( @@ -87,7 +126,7 @@ def scan(self, **kwargs): file = FileFactory.create() with pytest.raises(celery.exceptions.Retry): - tasks.run_check(task, request, "ExampleHookedCheck", file.id) + tasks.run_check(task, request, "ExampleHookedCheck", obj_id=file.id) assert request.log.error.calls == [ pretend.call("Error executing check ExampleHookedCheck: Scan failed") @@ -108,9 +147,8 @@ def test_invalid_check_name(self, db_request, monkeypatch): ) def test_run(self, db_session, num_objects, num_runs, monkeypatch): monkeypatch.setattr(tasks, "checks", test_checks) - files = [] for i in range(num_objects): - files.append(FileFactory.create()) + FileFactory.create() MalwareCheckFactory.create( name="ExampleHookedCheck", state=MalwareCheckState.Enabled @@ -133,15 +171,14 @@ def test_run(self, db_session, num_objects, num_runs, monkeypatch): pretend.call("Running backfill on %d Files." % num_runs) ] - assert enqueue_recorder.delay.calls == [ - pretend.call("ExampleHookedCheck", files[i].id) for i in range(num_runs) - ] + assert len(enqueue_recorder.delay.calls) == num_runs class TestSyncChecks: def test_no_updates(self, db_session, monkeypatch): monkeypatch.setattr(tasks, "checks", test_checks) monkeypatch.setattr(tasks.checks.ExampleScheduledCheck, "version", 2) + MalwareCheckFactory.create( name="ExampleHookedCheck", state=MalwareCheckState.Disabled ) diff --git a/warehouse/admin/routes.py b/warehouse/admin/routes.py index 2788f51519bf..2b8ca93a3541 100644 --- a/warehouse/admin/routes.py +++ b/warehouse/admin/routes.py @@ -140,8 +140,8 @@ def includeme(config): domain=warehouse, ) config.add_route( - "admin.checks.run_backfill", - "/admin/checks/{check_name}/run_backfill", + "admin.checks.run_evaluation", + "/admin/checks/{check_name}/run_evaluation", domain=warehouse, ) config.add_route("admin.verdicts.list", "/admin/verdicts/", domain=warehouse) diff --git a/warehouse/admin/templates/admin/malware/checks/detail.html b/warehouse/admin/templates/admin/malware/checks/detail.html index 77914488bb51..c52a5f354530 100644 --- a/warehouse/admin/templates/admin/malware/checks/detail.html +++ b/warehouse/admin/templates/admin/malware/checks/detail.html @@ -30,12 +30,22 @@
{{ c.schedule }}