From d75af1d2f1fbabfab606e904637f5f849ac389d6 Mon Sep 17 00:00:00 2001 From: Cristina Date: Thu, 30 Jan 2020 13:13:28 -0800 Subject: [PATCH] Implement scheduled checks #7093 (#7271) * Implement scheduled checks #7093 - Rename `run_backfill` to `run_evaluation` in admin malware view - Modify `run` and `scan` method signatures to accept `**kwargs` - Extend `run_check` to accommodate scheduled check functionality * Reduce unit test flakiness * Code review changes. Also replace `check.hooked_object` with `check.hooked_object.value` in check detail template. * tests, warehouse: enum fixes * Fix lint error Co-authored-by: William Woodruff --- tests/common/checks/scheduled.py | 4 +- tests/unit/admin/test_routes.py | 4 +- tests/unit/admin/views/test_checks.py | 59 ++++++++++----- tests/unit/malware/test_init.py | 10 +++ tests/unit/malware/test_tasks.py | 71 ++++++++++++++----- warehouse/admin/routes.py | 4 +- .../admin/malware/checks/detail.html | 18 ++++- .../templates/admin/malware/checks/index.html | 4 +- warehouse/admin/views/checks.py | 35 +++++---- warehouse/malware/__init__.py | 16 +++++ warehouse/malware/checks/base.py | 10 +-- warehouse/malware/tasks.py | 29 ++++++-- 12 files changed, 197 insertions(+), 67 deletions(-) diff --git a/tests/common/checks/scheduled.py b/tests/common/checks/scheduled.py index 128ce102a83b..d5c80962e45c 100644 --- a/tests/common/checks/scheduled.py +++ b/tests/common/checks/scheduled.py @@ -27,11 +27,11 @@ class ExampleScheduledCheck(MalwareCheckBase): def __init__(self, db): super().__init__(db) - def scan(self): + def scan(self, **kwargs): project = self.db.query(Project).first() self.add_verdict( project_id=project.id, - classification=VerdictClassification.benign, + classification=VerdictClassification.Benign, confidence=VerdictConfidence.High, message="Nothing to see here!", ) diff --git a/tests/unit/admin/test_routes.py b/tests/unit/admin/test_routes.py index 28538ad12dac..451b6ad5a9e0 100644 --- a/tests/unit/admin/test_routes.py 
+++ b/tests/unit/admin/test_routes.py @@ -133,8 +133,8 @@ def test_includeme(): domain=warehouse, ), pretend.call( - "admin.checks.run_backfill", - "/admin/checks/{check_name}/run_backfill", + "admin.checks.run_evaluation", + "/admin/checks/{check_name}/run_evaluation", domain=warehouse, ), pretend.call("admin.verdicts.list", "/admin/verdicts/", domain=warehouse), diff --git a/tests/unit/admin/views/test_checks.py b/tests/unit/admin/views/test_checks.py index c8fa6512aeaa..0c3baeb03369 100644 --- a/tests/unit/admin/views/test_checks.py +++ b/tests/unit/admin/views/test_checks.py @@ -16,7 +16,8 @@ from pyramid.httpexceptions import HTTPNotFound from warehouse.admin.views import checks as views -from warehouse.malware.models import MalwareCheckState +from warehouse.malware.models import MalwareCheckState, MalwareCheckType +from warehouse.malware.tasks import backfill, run_check from ....common.db.malware import MalwareCheckFactory @@ -46,6 +47,7 @@ def test_get_check(self, db_request): "check": check, "checks": [check], "states": MalwareCheckState, + "evaluation_run_size": 10000, } def test_get_check_many_versions(self, db_request): @@ -56,6 +58,7 @@ def test_get_check_many_versions(self, db_request): "check": check2, "checks": [check2, check1], "states": MalwareCheckState, + "evaluation_run_size": 10000, } def test_get_check_not_found(self, db_request): @@ -129,17 +132,17 @@ def test_change_to_invalid_state(self, db_request): assert check.state == initial_state -class TestRunBackfill: +class TestRunEvaluation: @pytest.mark.parametrize( ("check_state", "message"), [ ( MalwareCheckState.Disabled, - "Check must be in 'enabled' or 'evaluation' state to run a backfill.", + "Check must be in 'enabled' or 'evaluation' state to manually execute.", ), ( MalwareCheckState.WipedOut, - "Check must be in 'enabled' or 'evaluation' state to run a backfill.", + "Check must be in 'enabled' or 'evaluation' state to manually execute.", ), ], ) @@ -152,15 +155,21 @@ def 
test_invalid_backfill_parameters(self, db_request, check_state, message): ) db_request.route_path = pretend.call_recorder( - lambda *a, **kw: "/admin/checks/%s/run_backfill" % check.name + lambda *a, **kw: "/admin/checks/%s/run_evaluation" % check.name ) - views.run_backfill(db_request) + views.run_evaluation(db_request) assert db_request.session.flash.calls == [pretend.call(message, queue="error")] - def test_sucess(self, db_request): - check = MalwareCheckFactory.create(state=MalwareCheckState.Enabled) + @pytest.mark.parametrize( + ("check_type"), [MalwareCheckType.EventHook, MalwareCheckType.Scheduled] + ) + def test_success(self, db_request, check_type): + + check = MalwareCheckFactory.create( + check_type=check_type, state=MalwareCheckState.Enabled + ) db_request.matchdict["check_name"] = check.name db_request.session = pretend.stub( @@ -168,7 +177,7 @@ def test_sucess(self, db_request): ) db_request.route_path = pretend.call_recorder( - lambda *a, **kw: "/admin/checks/%s/run_backfill" % check.name + lambda *a, **kw: "/admin/checks/%s/run_evaluation" % check.name ) backfill_recorder = pretend.stub( @@ -177,13 +186,25 @@ def test_sucess(self, db_request): db_request.task = pretend.call_recorder(lambda *a, **kw: backfill_recorder) - views.run_backfill(db_request) - - assert db_request.session.flash.calls == [ - pretend.call( - "Running %s on 10000 %ss!" % (check.name, check.hooked_object.value), - queue="success", - ) - ] - - assert backfill_recorder.delay.calls == [pretend.call(check.name, 10000)] + views.run_evaluation(db_request) + + if check_type == MalwareCheckType.EventHook: + assert db_request.session.flash.calls == [ + pretend.call( + "Running %s on 10000 %ss!" 
+ % (check.name, check.hooked_object.value), + queue="success", + ) + ] + assert db_request.task.calls == [pretend.call(backfill)] + assert backfill_recorder.delay.calls == [pretend.call(check.name, 10000)] + elif check_type == MalwareCheckType.Scheduled: + assert db_request.session.flash.calls == [ + pretend.call("Running %s now!" % check.name, queue="success",) + ] + assert db_request.task.calls == [pretend.call(run_check)] + assert backfill_recorder.delay.calls == [ + pretend.call(check.name, manually_triggered=True) + ] + else: + raise Exception("Invalid check type: %s" % check_type) diff --git a/tests/unit/malware/test_init.py b/tests/unit/malware/test_init.py index c8e1f3d7acf4..32628118398d 100644 --- a/tests/unit/malware/test_init.py +++ b/tests/unit/malware/test_init.py @@ -14,9 +14,12 @@ import pretend +from celery.schedules import crontab + from warehouse import malware from warehouse.malware import utils from warehouse.malware.interfaces import IMalwareCheckService +from warehouse.malware.tasks import run_check from ...common import checks as test_checks from ...common.db.accounts import UserFactory @@ -165,6 +168,7 @@ def test_includeme(monkeypatch): registry=pretend.stub( settings={"malware_check.backend": "TestMalwareCheckService"} ), + add_periodic_task=pretend.call_recorder(lambda *a, **kw: None), ) malware.includeme(config) @@ -172,3 +176,9 @@ def test_includeme(monkeypatch): assert config.register_service_factory.calls == [ pretend.call(malware_check_class.create_service, IMalwareCheckService) ] + + assert config.add_periodic_task.calls == [ + pretend.call( + crontab(minute="0", hour="*/8"), run_check, args=("ExampleScheduledCheck",) + ) + ] diff --git a/tests/unit/malware/test_tasks.py b/tests/unit/malware/test_tasks.py index de278f4ad825..6a333c31b018 100644 --- a/tests/unit/malware/test_tasks.py +++ b/tests/unit/malware/test_tasks.py @@ -14,8 +14,6 @@ import pretend import pytest -from sqlalchemy.orm.exc import NoResultFound - from 
warehouse.malware import tasks from warehouse.malware.models import MalwareCheck, MalwareCheckState, MalwareVerdict @@ -34,45 +32,86 @@ def test_success(self, db_request, monkeypatch): name="ExampleHookedCheck", state=MalwareCheckState.Enabled ) task = pretend.stub() - tasks.run_check(task, db_request, "ExampleHookedCheck", file0.id) + tasks.run_check(task, db_request, "ExampleHookedCheck", obj_id=file0.id) assert db_request.route_url.calls == [ pretend.call("packaging.file", path=file0.path) ] assert db_request.db.query(MalwareVerdict).one() - def test_disabled_check(self, db_request, monkeypatch): + @pytest.mark.parametrize(("manually_triggered"), [True, False]) + def test_evaluation_run(self, db_session, monkeypatch, manually_triggered): + monkeypatch.setattr(tasks, "checks", test_checks) + MalwareCheckFactory.create( + name="ExampleScheduledCheck", state=MalwareCheckState.Evaluation + ) + ProjectFactory.create() + task = pretend.stub() + + request = pretend.stub( + db=db_session, + log=pretend.stub(info=pretend.call_recorder(lambda *args, **kwargs: None)), + ) + + tasks.run_check( + task, + request, + "ExampleScheduledCheck", + manually_triggered=manually_triggered, + ) + + if manually_triggered: + assert db_session.query(MalwareVerdict).one() + else: + assert request.log.info.calls == [ + pretend.call( + "ExampleScheduledCheck is in the `evaluation` state and must be \ +manually triggered to run." 
+ ) + ] + assert db_session.query(MalwareVerdict).all() == [] + + def test_disabled_check(self, db_session, monkeypatch): monkeypatch.setattr(tasks, "checks", test_checks) MalwareCheckFactory.create( name="ExampleHookedCheck", state=MalwareCheckState.Disabled ) task = pretend.stub() + request = pretend.stub( + db=db_session, + log=pretend.stub(info=pretend.call_recorder(lambda *args, **kwargs: None)), + ) file = FileFactory.create() - with pytest.raises(NoResultFound): - tasks.run_check(task, db_request, "ExampleHookedCheck", file.id) + tasks.run_check( + task, request, "ExampleHookedCheck", obj_id=file.id, + ) + + assert request.log.info.calls == [ + pretend.call("Check ExampleHookedCheck isn't active. Aborting.") + ] def test_missing_check(self, db_request, monkeypatch): monkeypatch.setattr(tasks, "checks", test_checks) task = pretend.stub() - file = FileFactory.create() - with pytest.raises(AttributeError): - tasks.run_check(task, db_request, "DoesNotExistCheck", file.id) + tasks.run_check( + task, db_request, "DoesNotExistCheck", + ) def test_retry(self, db_session, monkeypatch): + monkeypatch.setattr(tasks, "checks", test_checks) exc = Exception("Scan failed") def scan(self, **kwargs): raise exc - monkeypatch.setattr(tasks, "checks", test_checks) monkeypatch.setattr(tasks.checks.ExampleHookedCheck, "scan", scan) MalwareCheckFactory.create( - name="ExampleHookedCheck", state=MalwareCheckState.Evaluation + name="ExampleHookedCheck", state=MalwareCheckState.Enabled ) task = pretend.stub( @@ -87,7 +126,7 @@ def scan(self, **kwargs): file = FileFactory.create() with pytest.raises(celery.exceptions.Retry): - tasks.run_check(task, request, "ExampleHookedCheck", file.id) + tasks.run_check(task, request, "ExampleHookedCheck", obj_id=file.id) assert request.log.error.calls == [ pretend.call("Error executing check ExampleHookedCheck: Scan failed") @@ -108,9 +147,8 @@ def test_invalid_check_name(self, db_request, monkeypatch): ) def test_run(self, db_session, num_objects, 
num_runs, monkeypatch): monkeypatch.setattr(tasks, "checks", test_checks) - files = [] for i in range(num_objects): - files.append(FileFactory.create()) + FileFactory.create() MalwareCheckFactory.create( name="ExampleHookedCheck", state=MalwareCheckState.Enabled @@ -133,15 +171,14 @@ def test_run(self, db_session, num_objects, num_runs, monkeypatch): pretend.call("Running backfill on %d Files." % num_runs) ] - assert enqueue_recorder.delay.calls == [ - pretend.call("ExampleHookedCheck", files[i].id) for i in range(num_runs) - ] + assert len(enqueue_recorder.delay.calls) == num_runs class TestSyncChecks: def test_no_updates(self, db_session, monkeypatch): monkeypatch.setattr(tasks, "checks", test_checks) monkeypatch.setattr(tasks.checks.ExampleScheduledCheck, "version", 2) + MalwareCheckFactory.create( name="ExampleHookedCheck", state=MalwareCheckState.Disabled ) diff --git a/warehouse/admin/routes.py b/warehouse/admin/routes.py index 2788f51519bf..2b8ca93a3541 100644 --- a/warehouse/admin/routes.py +++ b/warehouse/admin/routes.py @@ -140,8 +140,8 @@ def includeme(config): domain=warehouse, ) config.add_route( - "admin.checks.run_backfill", - "/admin/checks/{check_name}/run_backfill", + "admin.checks.run_evaluation", + "/admin/checks/{check_name}/run_evaluation", domain=warehouse, ) config.add_route("admin.verdicts.list", "/admin/verdicts/", domain=warehouse) diff --git a/warehouse/admin/templates/admin/malware/checks/detail.html b/warehouse/admin/templates/admin/malware/checks/detail.html index 77914488bb51..c52a5f354530 100644 --- a/warehouse/admin/templates/admin/malware/checks/detail.html +++ b/warehouse/admin/templates/admin/malware/checks/detail.html @@ -30,12 +30,22 @@

Revision History

Version State + {% if check.check_type.value == "event_hook" %} + Hooked Object + {% else %} + Schedule + {% endif %} Created {% for c in checks %} {{ c.version }} {{ c.state.value }} + {% if check.check_type.value == "event_hook" %} + {{ c.hooked_object.value }} + {% else %} +
{{ c.schedule }}
+ {% endif %} {{ c.created }} {% endfor %} @@ -69,10 +79,14 @@

Change State

Run Evaluation

-
+
-

Run this check against 10,000 {{ check.hooked_object.value }}s, selected at random. This is used to evaluate the efficacy of a check.

+ {% if check.check_type.value == "event_hook" %} +

Run this check against {{ evaluation_run_size }} {{ check.hooked_object.value }}s, selected at random. This is used to evaluate the efficacy of a check.

+ {% else %} +

Execute this check now.

+ {% endif %}
diff --git a/warehouse/admin/templates/admin/malware/checks/index.html b/warehouse/admin/templates/admin/malware/checks/index.html index 5717849e2579..2601bfe0f62a 100644 --- a/warehouse/admin/templates/admin/malware/checks/index.html +++ b/warehouse/admin/templates/admin/malware/checks/index.html @@ -26,7 +26,7 @@ Check Name State - Revisions + Type Last Modified Description @@ -38,7 +38,7 @@ {{ check.state.value }} - {{ check.version }} + {{ check.check_type.value }} {{ check.created }} {{ check.short_description }} diff --git a/warehouse/admin/views/checks.py b/warehouse/admin/views/checks.py index b23465fd4a49..0cd1761c8456 100644 --- a/warehouse/admin/views/checks.py +++ b/warehouse/admin/views/checks.py @@ -14,8 +14,10 @@ from pyramid.view import view_config from sqlalchemy.orm.exc import NoResultFound -from warehouse.malware.models import MalwareCheck, MalwareCheckState -from warehouse.malware.tasks import backfill, remove_verdicts +from warehouse.malware.models import MalwareCheck, MalwareCheckState, MalwareCheckType +from warehouse.malware.tasks import backfill, remove_verdicts, run_check + +EVALUATION_RUN_SIZE = 10000 @view_config( @@ -52,36 +54,45 @@ def get_check(request): .all() ) - return {"check": check, "checks": all_checks, "states": MalwareCheckState} + return { + "check": check, + "checks": all_checks, + "states": MalwareCheckState, + "evaluation_run_size": EVALUATION_RUN_SIZE, + } @view_config( - route_name="admin.checks.run_backfill", + route_name="admin.checks.run_evaluation", permission="admin", request_method="POST", uses_session=True, require_methods=False, require_csrf=True, ) -def run_backfill(request): +def run_evaluation(request): check = get_check_by_name(request.db, request.matchdict["check_name"]) - num_objects = 10000 if check.state not in (MalwareCheckState.Enabled, MalwareCheckState.Evaluation): request.session.flash( - f"Check must be in 'enabled' or 'evaluation' state to run a backfill.", + f"Check must be in 'enabled' or 
'evaluation' state to manually execute.", queue="error", ) return HTTPSeeOther( request.route_path("admin.checks.detail", check_name=check.name) ) - request.session.flash( - f"Running {check.name} on {num_objects} {check.hooked_object.value}s!", - queue="success", - ) + if check.check_type == MalwareCheckType.EventHook: + request.session.flash( + f"Running {check.name} on {EVALUATION_RUN_SIZE} {check.hooked_object.value}s\ +!", + queue="success", + ) + request.task(backfill).delay(check.name, EVALUATION_RUN_SIZE) - request.task(backfill).delay(check.name, num_objects) + else: + request.session.flash(f"Running {check.name} now!", queue="success") + request.task(run_check).delay(check.name, manually_triggered=True) return HTTPSeeOther( request.route_path("admin.checks.detail", check_name=check.name) diff --git a/warehouse/malware/__init__.py b/warehouse/malware/__init__.py index f54a9e89b4f5..a1b76b82f2b5 100644 --- a/warehouse/malware/__init__.py +++ b/warehouse/malware/__init__.py @@ -10,9 +10,16 @@ # See the License for the specific language governing permissions and # limitations under the License. +import inspect + +from celery.schedules import crontab + +import warehouse.malware.checks as checks + from warehouse import db from warehouse.malware import utils from warehouse.malware.interfaces import IMalwareCheckService +from warehouse.malware.tasks import run_check @db.listens_for(db.Session, "after_flush") @@ -57,3 +64,12 @@ def includeme(config): config.register_service_factory( malware_check_class.create_service, IMalwareCheckService ) + + # Add scheduled tasks for every scheduled Malware Check. 
+ all_checks = inspect.getmembers(checks, inspect.isclass) + for check_obj in all_checks: + check = check_obj[1] + if check.check_type == "scheduled": + config.add_periodic_task( + crontab(**check.schedule), run_check, args=(check_obj[0],) + ) diff --git a/warehouse/malware/checks/base.py b/warehouse/malware/checks/base.py index 67810a367203..44f8230a1c2c 100644 --- a/warehouse/malware/checks/base.py +++ b/warehouse/malware/checks/base.py @@ -18,7 +18,7 @@ class MalwareCheckBase: def __init__(self, db): self.db = db self._name = self.__class__.__name__ - self._load_check_id() + self._load_check_fields() self._verdicts = [] @classmethod @@ -26,7 +26,7 @@ def prepare(cls, request, obj_id): """ Prepares some context for scanning the given object. """ - kwargs = {} + kwargs = {"obj_id": obj_id} model = getattr(models, cls.hooked_object) kwargs["obj"] = request.db.query(model).get(obj_id) @@ -60,9 +60,9 @@ def backfill(self, sample=1): backfill on the entire corpus. """ - def _load_check_id(self): - (self.id,) = ( - self.db.query(MalwareCheck.id) + def _load_check_fields(self): + self.id, self.state = ( + self.db.query(MalwareCheck.id, MalwareCheck.state) .filter(MalwareCheck.name == self._name) .filter( MalwareCheck.state.in_( diff --git a/warehouse/malware/tasks.py b/warehouse/malware/tasks.py index dc0274cbee98..1fbe46b7a638 100644 --- a/warehouse/malware/tasks.py +++ b/warehouse/malware/tasks.py @@ -12,6 +12,8 @@ import inspect +from sqlalchemy.orm.exc import NoResultFound + import warehouse.malware.checks as checks import warehouse.packaging.models as packaging_models @@ -21,11 +23,30 @@ @task(bind=True, ignore_result=True, acks_late=True, retry_backoff=True) -def run_check(task, request, check_name, obj_id): - check = getattr(checks, check_name)(request.db) +def run_check(task, request, check_name, obj_id=None, manually_triggered=False): try: + check = getattr(checks, check_name)(request.db) + except NoResultFound: + request.log.info("Check %s isn't active. 
Aborting." % check_name) + return + + # Don't run scheduled checks if they are in evaluation mode, unless manually + # triggered. + if check.state == MalwareCheckState.Evaluation and not manually_triggered: + request.log.info( + "%s is in the `evaluation` state and must be manually triggered to run." + % check_name + ) + return + + kwargs = {} + + # Hooked checks require `obj_id`s. + if obj_id is not None: kwargs = check.prepare(request, obj_id) - check.run(obj_id=obj_id, **kwargs) + + try: + check.run(**kwargs) except Exception as exc: request.log.error("Error executing check %s: %s" % (check_name, str(exc))) raise task.retry(exc=exc) @@ -43,7 +64,7 @@ def backfill(task, request, check_name, num_objects): request.log.info("Running backfill on %d %ss." % (num_objects, check.hooked_object)) for (elem_id,) in query: - request.task(run_check).delay(check_name, elem_id) + request.task(run_check).delay(check_name, elem_id, manually_triggered=True) @task(bind=True, ignore_result=True, acks_late=True)