Skip to content

Commit

Permalink
Misc cleanup and TODOs on malware checks. (#7355)
Browse files Browse the repository at this point in the history
* Misc cleanup and TODOs on malware checks.

    - Change backfill function to invoke `IMalwareCheckService` interface
    - Add support for `kwargs to `IMalwareCheckService` interface
    - Rename variable from reserved word `file` to `release_file`
    - Add `FatalCheckException` for non-retryable exceptions
    - Replace `MALWARE_CHECK_BACKEND` in dev/environment

* Make `IMalwareService` the entrypoint for `run_check`

- Add `run_scheduled_check` task that invokes this interface.
- Remove useless utility method
- Move `FatalCheckException` into warehouse/malware/errors.py.
  • Loading branch information
xmunoz authored Feb 7, 2020
1 parent 32ed6ed commit 1b392c6
Show file tree
Hide file tree
Showing 15 changed files with 190 additions and 68 deletions.
3 changes: 1 addition & 2 deletions dev/environment
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,7 @@ MAIL_BACKEND=warehouse.email.services.SMTPEmailSender host=smtp port=2525 ssl=fa

BREACHED_PASSWORDS=warehouse.accounts.NullPasswordBreachedService

#TODO: change this to PrinterMalwareCheckService before deploy
MALWARE_CHECK_BACKEND=warehouse.malware.services.DatabaseMalwareCheckService
MALWARE_CHECK_BACKEND=warehouse.malware.services.PrinterMalwareCheckService

METRICS_BACKEND=warehouse.metrics.DataDogMetrics host=notdatadog

Expand Down
3 changes: 2 additions & 1 deletion tests/common/checks/hooked.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
# limitations under the License.

from warehouse.malware.checks.base import MalwareCheckBase
from warehouse.malware.errors import FatalCheckException
from warehouse.malware.models import VerdictClassification, VerdictConfidence


Expand All @@ -29,7 +30,7 @@ def __init__(self, db):
def scan(self, **kwargs):
file_id = kwargs.get("obj_id")
if file_id is None:
return
raise FatalCheckException("Missing required kwarg `obj_id`")

self.add_verdict(
file_id=file_id,
Expand Down
4 changes: 2 additions & 2 deletions tests/unit/admin/views/test_checks.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@

from warehouse.admin.views import checks as views
from warehouse.malware.models import MalwareCheckState, MalwareCheckType
from warehouse.malware.tasks import backfill, run_check
from warehouse.malware.tasks import backfill, run_scheduled_check

from ....common.db.malware import MalwareCheckFactory

Expand Down Expand Up @@ -208,7 +208,7 @@ def test_success(self, db_request, check_type):
assert db_request.session.flash.calls == [
pretend.call("Running %s now!" % check.name, queue="success",)
]
assert db_request.task.calls == [pretend.call(run_check)]
assert db_request.task.calls == [pretend.call(run_scheduled_check)]
assert backfill_recorder.delay.calls == [
pretend.call(check.name, manually_triggered=True)
]
Expand Down
5 changes: 2 additions & 3 deletions tests/unit/malware/checks/setup_patterns/test_check.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,9 +43,8 @@ def test_scan_missing_kwargs(db_session, obj, file_url):
name="SetupPatternCheck", state=MalwareCheckState.Enabled
)
check = c.SetupPatternCheck(db_session)
check.scan(obj=obj, file_url=file_url)

assert check._verdicts == []
with pytest.raises(c.FatalCheckException):
check.scan(obj=obj, file_url=file_url)


def test_scan_non_sdist(db_session):
Expand Down
17 changes: 9 additions & 8 deletions tests/unit/malware/test_init.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,8 @@
from celery.schedules import crontab

from warehouse import malware
from warehouse.malware import utils
from warehouse.malware.interfaces import IMalwareCheckService
from warehouse.malware.tasks import run_check
from warehouse.malware.tasks import run_scheduled_check

from ...common import checks as test_checks
from ...common.db.accounts import UserFactory
Expand All @@ -30,7 +29,7 @@ def test_determine_malware_checks_no_checks(monkeypatch, db_request):
def get_enabled_hooked_checks(session):
return defaultdict(list)

monkeypatch.setattr(utils, "get_enabled_hooked_checks", get_enabled_hooked_checks)
monkeypatch.setattr(malware, "get_enabled_hooked_checks", get_enabled_hooked_checks)

project = ProjectFactory.create(name="foo")
release = ReleaseFactory.create(project=project)
Expand All @@ -49,7 +48,7 @@ def get_enabled_hooked_checks(session):
result["Release"] = ["Check3"]
return result

monkeypatch.setattr(utils, "get_enabled_hooked_checks", get_enabled_hooked_checks)
monkeypatch.setattr(malware, "get_enabled_hooked_checks", get_enabled_hooked_checks)

project = ProjectFactory.create(name="foo")
release = ReleaseFactory.create(project=project)
Expand All @@ -68,7 +67,7 @@ def get_enabled_hooked_checks(session):
result["Release"] = ["Check3"]
return result

monkeypatch.setattr(utils, "get_enabled_hooked_checks", get_enabled_hooked_checks)
monkeypatch.setattr(malware, "get_enabled_hooked_checks", get_enabled_hooked_checks)

user = UserFactory.create()

Expand All @@ -85,7 +84,7 @@ def get_enabled_hooked_checks(session):
result["Release"] = ["Check3"]
return result

monkeypatch.setattr(utils, "get_enabled_hooked_checks", get_enabled_hooked_checks)
monkeypatch.setattr(malware, "get_enabled_hooked_checks", get_enabled_hooked_checks)

project = ProjectFactory.create(name="foo")
release = ReleaseFactory.create(project=project)
Expand All @@ -105,7 +104,7 @@ def get_enabled_hooked_checks(session):
result["Release"] = ["Check3"]
return result

monkeypatch.setattr(utils, "get_enabled_hooked_checks", get_enabled_hooked_checks)
monkeypatch.setattr(malware, "get_enabled_hooked_checks", get_enabled_hooked_checks)

project = ProjectFactory.create(name="foo")
release = ReleaseFactory.create(project=project)
Expand Down Expand Up @@ -179,6 +178,8 @@ def test_includeme(monkeypatch):

assert config.add_periodic_task.calls == [
pretend.call(
crontab(minute="0", hour="*/8"), run_check, args=("ExampleScheduledCheck",)
crontab(minute="0", hour="*/8"),
run_scheduled_check,
args=("ExampleScheduledCheck",),
)
]
40 changes: 33 additions & 7 deletions tests/unit/malware/test_services.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
# limitations under the License.

import pretend
import pytest

from zope.interface.verify import verifyClass

Expand All @@ -31,13 +32,14 @@ def test_create_service(self):
service = PrinterMalwareCheckService.create_service(None, request)
assert service.executor == print

def test_run_checks(self, capfd):
@pytest.mark.parametrize(("kwargs"), [{}, {"manually_triggered": True}])
def test_run_checks(self, capfd, kwargs):
request = pretend.stub()
service = PrinterMalwareCheckService.create_service(None, request)
checks = ["one", "two", "three"]
service.run_checks(checks)
service.run_checks(checks, **kwargs)
out, err = capfd.readouterr()
assert out == "one\ntwo\nthree\n"
assert out == "".join(["%s %s\n" % (check, kwargs) for check in checks])


class TestDatabaseMalwareService:
Expand All @@ -50,12 +52,36 @@ def test_create_service(self, db_request):
service = DatabaseMalwareCheckService.create_service(None, db_request)
assert service.executor == db_request.task(run_check).delay

def test_run_checks(self, db_request):
_delay = pretend.call_recorder(lambda *args: None)
def test_run_hooked_check(self, db_request):
_delay = pretend.call_recorder(lambda *args, **kwargs: None)
db_request.task = lambda x: pretend.stub(delay=_delay)
service = DatabaseMalwareCheckService.create_service(None, db_request)
checks = [
"MyTestCheck:ba70267f-fabf-496f-9ac2-d237a983b187",
"AnotherCheck:44f57b0e-c5b0-47c5-8713-341cf392efe2",
"FinalCheck:e8518a15-8f01-430e-8f5b-87644007c9c0",
]
service.run_checks(checks)
assert _delay.calls == [
pretend.call("MyTestCheck", obj_id="ba70267f-fabf-496f-9ac2-d237a983b187"),
pretend.call("AnotherCheck", obj_id="44f57b0e-c5b0-47c5-8713-341cf392efe2"),
pretend.call("FinalCheck", obj_id="e8518a15-8f01-430e-8f5b-87644007c9c0"),
]

def test_run_scheduled_check(self, db_request):
_delay = pretend.call_recorder(lambda *args, **kwargs: None)
db_request.task = lambda x: pretend.stub(delay=_delay)
service = DatabaseMalwareCheckService.create_service(None, db_request)
checks = ["MyTestCheck:ba70267f-fabf-496f-9ac2-d237a983b187"]
checks = ["MyTestScheduledCheck"]
service.run_checks(checks)
assert _delay.calls == [pretend.call("MyTestScheduledCheck")]

def test_run_triggered_check(self, db_request):
_delay = pretend.call_recorder(lambda *args, **kwargs: None)
db_request.task = lambda x: pretend.stub(delay=_delay)
service = DatabaseMalwareCheckService.create_service(None, db_request)
checks = ["MyTriggeredCheck"]
service.run_checks(checks, manually_triggered=True)
assert _delay.calls == [
pretend.call("MyTestCheck", "ba70267f-fabf-496f-9ac2-d237a983b187")
pretend.call("MyTriggeredCheck", manually_triggered=True)
]
86 changes: 77 additions & 9 deletions tests/unit/malware/test_tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

from warehouse.malware import tasks
from warehouse.malware.models import MalwareCheck, MalwareCheckState, MalwareVerdict
from warehouse.malware.services import PrinterMalwareCheckService

from ...common import checks as test_checks
from ...common.db.malware import MalwareCheckFactory, MalwareVerdictFactory
Expand Down Expand Up @@ -101,6 +102,28 @@ def test_missing_check(self, db_request, monkeypatch):
task, db_request, "DoesNotExistCheck",
)

def test_missing_obj_id(self, db_session, monkeypatch):
monkeypatch.setattr(tasks, "checks", test_checks)
task = pretend.stub()

MalwareCheckFactory.create(
name="ExampleHookedCheck", state=MalwareCheckState.Enabled
)
task = pretend.stub()

request = pretend.stub(
db=db_session,
log=pretend.stub(error=pretend.call_recorder(lambda *args, **kwargs: None)),
)

tasks.run_check(task, request, "ExampleHookedCheck")

assert request.log.error.calls == [
pretend.call(
"Fatal exception: ExampleHookedCheck: Missing required kwarg `obj_id`"
)
]

def test_retry(self, db_session, monkeypatch):
monkeypatch.setattr(tasks, "checks", test_checks)
exc = Exception("Scan failed")
Expand Down Expand Up @@ -135,6 +158,37 @@ def scan(self, **kwargs):
assert task.retry.calls == [pretend.call(exc=exc)]


class TestRunScheduledCheck:
def test_invalid_check_name(self, db_request, monkeypatch):
monkeypatch.setattr(tasks, "checks", test_checks)
task = pretend.stub()
with pytest.raises(AttributeError):
tasks.run_scheduled_check(task, db_request, "DoesNotExist")

def test_run_check(self, db_session, capfd, monkeypatch):
MalwareCheckFactory.create(
name="ExampleScheduledCheck", state=MalwareCheckState.Enabled
)

request = pretend.stub(
db=db_session,
find_service_factory=pretend.call_recorder(
lambda interface: PrinterMalwareCheckService.create_service
),
)

task = pretend.stub()

tasks.run_scheduled_check(task, request, "ExampleScheduledCheck")

assert request.find_service_factory.calls == [
pretend.call(tasks.IMalwareCheckService)
]

out, err = capfd.readouterr()
assert out == "ExampleScheduledCheck {'manually_triggered': False}\n"


class TestBackfill:
def test_invalid_check_name(self, db_request, monkeypatch):
monkeypatch.setattr(tasks, "checks", test_checks)
Expand All @@ -145,33 +199,47 @@ def test_invalid_check_name(self, db_request, monkeypatch):
@pytest.mark.parametrize(
("num_objects", "num_runs"), [(11, 1), (11, 11), (101, 90)]
)
def test_run(self, db_session, num_objects, num_runs, monkeypatch):
def test_run(self, db_session, capfd, num_objects, num_runs, monkeypatch):
monkeypatch.setattr(tasks, "checks", test_checks)

ids = []
for i in range(num_objects):
FileFactory.create()
ids.append(FileFactory.create().id)

MalwareCheckFactory.create(
name="ExampleHookedCheck", state=MalwareCheckState.Enabled
)

enqueue_recorder = pretend.stub(
delay=pretend.call_recorder(lambda *a, **kw: None)
)
task = pretend.call_recorder(lambda *args, **kwargs: enqueue_recorder)

request = pretend.stub(
db=db_session,
log=pretend.stub(info=pretend.call_recorder(lambda *args, **kwargs: None)),
task=task,
find_service_factory=pretend.call_recorder(
lambda interface: PrinterMalwareCheckService.create_service
),
)

task = pretend.stub()

tasks.backfill(task, request, "ExampleHookedCheck", num_runs)

assert request.log.info.calls == [
pretend.call("Running backfill on %d Files." % num_runs)
]

assert len(enqueue_recorder.delay.calls) == num_runs
assert request.find_service_factory.calls == [
pretend.call(tasks.IMalwareCheckService)
]

out, err = capfd.readouterr()
num_output_lines = 0
for file_id in ids:
logged_output = "ExampleHookedCheck:%s %s\n" % (
file_id,
{"manually_triggered": True},
)
num_output_lines += 1 if logged_output in out else 0

assert num_output_lines == num_runs


class TestSyncChecks:
Expand Down
4 changes: 2 additions & 2 deletions warehouse/admin/views/checks.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
from sqlalchemy.orm.exc import NoResultFound

from warehouse.malware.models import MalwareCheck, MalwareCheckState, MalwareCheckType
from warehouse.malware.tasks import backfill, remove_verdicts, run_check
from warehouse.malware.tasks import backfill, remove_verdicts, run_scheduled_check

EVALUATION_RUN_SIZE = 10000

Expand Down Expand Up @@ -94,7 +94,7 @@ def run_evaluation(request):

else:
request.session.flash(f"Running {check.name} now!", queue="success")
request.task(run_check).delay(check.name, manually_triggered=True)
request.task(run_scheduled_check).delay(check.name, manually_triggered=True)

return HTTPSeeOther(
request.route_path("admin.checks.detail", check_name=check.name)
Expand Down
11 changes: 6 additions & 5 deletions warehouse/malware/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,10 @@
import warehouse.malware.checks as checks

from warehouse import db
from warehouse.malware import utils
from warehouse.malware.interfaces import IMalwareCheckService
from warehouse.malware.tasks import run_check
from warehouse.malware.models import MalwareCheckObjectType
from warehouse.malware.tasks import run_scheduled_check
from warehouse.malware.utils import get_enabled_hooked_checks


@db.listens_for(db.Session, "after_flush")
Expand All @@ -31,13 +32,13 @@ def determine_malware_checks(config, session, flush_context):
[
obj.__class__.__name__
for obj in session.new
if obj.__class__.__name__ in utils.valid_check_types()
if obj.__class__.__name__ in MalwareCheckObjectType.__members__
]
):
return

malware_checks = session.info.setdefault("warehouse.malware.checks", set())
enabled_checks = utils.get_enabled_hooked_checks(session)
enabled_checks = get_enabled_hooked_checks(session)
for obj in session.new:
for check_name in enabled_checks.get(obj.__class__.__name__, []):
malware_checks.update([f"{check_name}:{obj.id}"])
Expand Down Expand Up @@ -71,5 +72,5 @@ def includeme(config):
check = check_obj[1]
if check.check_type == "scheduled":
config.add_periodic_task(
crontab(**check.schedule), run_check, args=(check_obj[0],)
crontab(**check.schedule), run_scheduled_check, args=(check_obj[0],)
)
Loading

0 comments on commit 1b392c6

Please sign in to comment.