Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Misc cleanup and TODOs on malware checks. #7355

Merged: 2 commits merged on Feb 7, 2020
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 1 addition & 2 deletions dev/environment
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,7 @@ MAIL_BACKEND=warehouse.email.services.SMTPEmailSender host=smtp port=2525 ssl=fa

BREACHED_PASSWORDS=warehouse.accounts.NullPasswordBreachedService

#TODO: change this to PrinterMalwareCheckService before deploy
MALWARE_CHECK_BACKEND=warehouse.malware.services.DatabaseMalwareCheckService
MALWARE_CHECK_BACKEND=warehouse.malware.services.PrinterMalwareCheckService

METRICS_BACKEND=warehouse.metrics.DataDogMetrics host=notdatadog

Expand Down
3 changes: 2 additions & 1 deletion tests/common/checks/hooked.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
# limitations under the License.

from warehouse.malware.checks.base import MalwareCheckBase
from warehouse.malware.errors import FatalCheckException
from warehouse.malware.models import VerdictClassification, VerdictConfidence


Expand All @@ -29,7 +30,7 @@ def __init__(self, db):
def scan(self, **kwargs):
file_id = kwargs.get("obj_id")
if file_id is None:
return
raise FatalCheckException("Missing required kwarg `obj_id`")

self.add_verdict(
file_id=file_id,
Expand Down
4 changes: 2 additions & 2 deletions tests/unit/admin/views/test_checks.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@

from warehouse.admin.views import checks as views
from warehouse.malware.models import MalwareCheckState, MalwareCheckType
from warehouse.malware.tasks import backfill, run_check
from warehouse.malware.tasks import backfill, run_scheduled_check

from ....common.db.malware import MalwareCheckFactory

Expand Down Expand Up @@ -208,7 +208,7 @@ def test_success(self, db_request, check_type):
assert db_request.session.flash.calls == [
pretend.call("Running %s now!" % check.name, queue="success",)
]
assert db_request.task.calls == [pretend.call(run_check)]
assert db_request.task.calls == [pretend.call(run_scheduled_check)]
assert backfill_recorder.delay.calls == [
pretend.call(check.name, manually_triggered=True)
]
Expand Down
5 changes: 2 additions & 3 deletions tests/unit/malware/checks/setup_patterns/test_check.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,9 +43,8 @@ def test_scan_missing_kwargs(db_session, obj, file_url):
name="SetupPatternCheck", state=MalwareCheckState.Enabled
)
check = c.SetupPatternCheck(db_session)
check.scan(obj=obj, file_url=file_url)

assert check._verdicts == []
with pytest.raises(c.FatalCheckException):
check.scan(obj=obj, file_url=file_url)


def test_scan_non_sdist(db_session):
Expand Down
17 changes: 9 additions & 8 deletions tests/unit/malware/test_init.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,8 @@
from celery.schedules import crontab

from warehouse import malware
from warehouse.malware import utils
from warehouse.malware.interfaces import IMalwareCheckService
from warehouse.malware.tasks import run_check
from warehouse.malware.tasks import run_scheduled_check

from ...common import checks as test_checks
from ...common.db.accounts import UserFactory
Expand All @@ -30,7 +29,7 @@ def test_determine_malware_checks_no_checks(monkeypatch, db_request):
def get_enabled_hooked_checks(session):
return defaultdict(list)

monkeypatch.setattr(utils, "get_enabled_hooked_checks", get_enabled_hooked_checks)
monkeypatch.setattr(malware, "get_enabled_hooked_checks", get_enabled_hooked_checks)

project = ProjectFactory.create(name="foo")
release = ReleaseFactory.create(project=project)
Expand All @@ -49,7 +48,7 @@ def get_enabled_hooked_checks(session):
result["Release"] = ["Check3"]
return result

monkeypatch.setattr(utils, "get_enabled_hooked_checks", get_enabled_hooked_checks)
monkeypatch.setattr(malware, "get_enabled_hooked_checks", get_enabled_hooked_checks)

project = ProjectFactory.create(name="foo")
release = ReleaseFactory.create(project=project)
Expand All @@ -68,7 +67,7 @@ def get_enabled_hooked_checks(session):
result["Release"] = ["Check3"]
return result

monkeypatch.setattr(utils, "get_enabled_hooked_checks", get_enabled_hooked_checks)
monkeypatch.setattr(malware, "get_enabled_hooked_checks", get_enabled_hooked_checks)

user = UserFactory.create()

Expand All @@ -85,7 +84,7 @@ def get_enabled_hooked_checks(session):
result["Release"] = ["Check3"]
return result

monkeypatch.setattr(utils, "get_enabled_hooked_checks", get_enabled_hooked_checks)
monkeypatch.setattr(malware, "get_enabled_hooked_checks", get_enabled_hooked_checks)

project = ProjectFactory.create(name="foo")
release = ReleaseFactory.create(project=project)
Expand All @@ -105,7 +104,7 @@ def get_enabled_hooked_checks(session):
result["Release"] = ["Check3"]
return result

monkeypatch.setattr(utils, "get_enabled_hooked_checks", get_enabled_hooked_checks)
monkeypatch.setattr(malware, "get_enabled_hooked_checks", get_enabled_hooked_checks)

project = ProjectFactory.create(name="foo")
release = ReleaseFactory.create(project=project)
Expand Down Expand Up @@ -179,6 +178,8 @@ def test_includeme(monkeypatch):

assert config.add_periodic_task.calls == [
pretend.call(
crontab(minute="0", hour="*/8"), run_check, args=("ExampleScheduledCheck",)
crontab(minute="0", hour="*/8"),
run_scheduled_check,
args=("ExampleScheduledCheck",),
)
]
40 changes: 33 additions & 7 deletions tests/unit/malware/test_services.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
# limitations under the License.

import pretend
import pytest

from zope.interface.verify import verifyClass

Expand All @@ -31,13 +32,14 @@ def test_create_service(self):
service = PrinterMalwareCheckService.create_service(None, request)
assert service.executor == print

def test_run_checks(self, capfd):
@pytest.mark.parametrize(("kwargs"), [{}, {"manually_triggered": True}])
def test_run_checks(self, capfd, kwargs):
xmunoz marked this conversation as resolved.
Show resolved Hide resolved
request = pretend.stub()
service = PrinterMalwareCheckService.create_service(None, request)
checks = ["one", "two", "three"]
service.run_checks(checks)
service.run_checks(checks, **kwargs)
out, err = capfd.readouterr()
assert out == "one\ntwo\nthree\n"
assert out == "".join(["%s %s\n" % (check, kwargs) for check in checks])


class TestDatabaseMalwareService:
Expand All @@ -50,12 +52,36 @@ def test_create_service(self, db_request):
service = DatabaseMalwareCheckService.create_service(None, db_request)
assert service.executor == db_request.task(run_check).delay

def test_run_checks(self, db_request):
_delay = pretend.call_recorder(lambda *args: None)
def test_run_hooked_check(self, db_request):
_delay = pretend.call_recorder(lambda *args, **kwargs: None)
db_request.task = lambda x: pretend.stub(delay=_delay)
service = DatabaseMalwareCheckService.create_service(None, db_request)
checks = [
"MyTestCheck:ba70267f-fabf-496f-9ac2-d237a983b187",
"AnotherCheck:44f57b0e-c5b0-47c5-8713-341cf392efe2",
"FinalCheck:e8518a15-8f01-430e-8f5b-87644007c9c0",
]
service.run_checks(checks)
assert _delay.calls == [
pretend.call("MyTestCheck", obj_id="ba70267f-fabf-496f-9ac2-d237a983b187"),
pretend.call("AnotherCheck", obj_id="44f57b0e-c5b0-47c5-8713-341cf392efe2"),
pretend.call("FinalCheck", obj_id="e8518a15-8f01-430e-8f5b-87644007c9c0"),
]

def test_run_scheduled_check(self, db_request):
_delay = pretend.call_recorder(lambda *args, **kwargs: None)
db_request.task = lambda x: pretend.stub(delay=_delay)
service = DatabaseMalwareCheckService.create_service(None, db_request)
checks = ["MyTestCheck:ba70267f-fabf-496f-9ac2-d237a983b187"]
checks = ["MyTestScheduledCheck"]
service.run_checks(checks)
assert _delay.calls == [pretend.call("MyTestScheduledCheck")]

def test_run_triggered_check(self, db_request):
_delay = pretend.call_recorder(lambda *args, **kwargs: None)
db_request.task = lambda x: pretend.stub(delay=_delay)
service = DatabaseMalwareCheckService.create_service(None, db_request)
checks = ["MyTriggeredCheck"]
service.run_checks(checks, manually_triggered=True)
assert _delay.calls == [
pretend.call("MyTestCheck", "ba70267f-fabf-496f-9ac2-d237a983b187")
pretend.call("MyTriggeredCheck", manually_triggered=True)
]
86 changes: 77 additions & 9 deletions tests/unit/malware/test_tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

from warehouse.malware import tasks
from warehouse.malware.models import MalwareCheck, MalwareCheckState, MalwareVerdict
from warehouse.malware.services import PrinterMalwareCheckService

from ...common import checks as test_checks
from ...common.db.malware import MalwareCheckFactory, MalwareVerdictFactory
Expand Down Expand Up @@ -101,6 +102,28 @@ def test_missing_check(self, db_request, monkeypatch):
task, db_request, "DoesNotExistCheck",
)

def test_missing_obj_id(self, db_session, monkeypatch):
monkeypatch.setattr(tasks, "checks", test_checks)
task = pretend.stub()

MalwareCheckFactory.create(
name="ExampleHookedCheck", state=MalwareCheckState.Enabled
)
task = pretend.stub()

request = pretend.stub(
db=db_session,
log=pretend.stub(error=pretend.call_recorder(lambda *args, **kwargs: None)),
)

tasks.run_check(task, request, "ExampleHookedCheck")

assert request.log.error.calls == [
pretend.call(
"Fatal exception: ExampleHookedCheck: Missing required kwarg `obj_id`"
)
]

def test_retry(self, db_session, monkeypatch):
monkeypatch.setattr(tasks, "checks", test_checks)
exc = Exception("Scan failed")
Expand Down Expand Up @@ -135,6 +158,37 @@ def scan(self, **kwargs):
assert task.retry.calls == [pretend.call(exc=exc)]


class TestRunScheduledCheck:
def test_invalid_check_name(self, db_request, monkeypatch):
monkeypatch.setattr(tasks, "checks", test_checks)
task = pretend.stub()
with pytest.raises(AttributeError):
tasks.run_scheduled_check(task, db_request, "DoesNotExist")

def test_run_check(self, db_session, capfd, monkeypatch):
MalwareCheckFactory.create(
name="ExampleScheduledCheck", state=MalwareCheckState.Enabled
)

request = pretend.stub(
db=db_session,
find_service_factory=pretend.call_recorder(
lambda interface: PrinterMalwareCheckService.create_service
),
)

task = pretend.stub()

tasks.run_scheduled_check(task, request, "ExampleScheduledCheck")

assert request.find_service_factory.calls == [
pretend.call(tasks.IMalwareCheckService)
]

out, err = capfd.readouterr()
assert out == "ExampleScheduledCheck {'manually_triggered': False}\n"


class TestBackfill:
def test_invalid_check_name(self, db_request, monkeypatch):
monkeypatch.setattr(tasks, "checks", test_checks)
Expand All @@ -145,33 +199,47 @@ def test_invalid_check_name(self, db_request, monkeypatch):
@pytest.mark.parametrize(
("num_objects", "num_runs"), [(11, 1), (11, 11), (101, 90)]
)
def test_run(self, db_session, num_objects, num_runs, monkeypatch):
def test_run(self, db_session, capfd, num_objects, num_runs, monkeypatch):
monkeypatch.setattr(tasks, "checks", test_checks)

ids = []
for i in range(num_objects):
FileFactory.create()
ids.append(FileFactory.create().id)

MalwareCheckFactory.create(
name="ExampleHookedCheck", state=MalwareCheckState.Enabled
)

enqueue_recorder = pretend.stub(
delay=pretend.call_recorder(lambda *a, **kw: None)
)
task = pretend.call_recorder(lambda *args, **kwargs: enqueue_recorder)

request = pretend.stub(
db=db_session,
log=pretend.stub(info=pretend.call_recorder(lambda *args, **kwargs: None)),
task=task,
find_service_factory=pretend.call_recorder(
lambda interface: PrinterMalwareCheckService.create_service
),
)

task = pretend.stub()

tasks.backfill(task, request, "ExampleHookedCheck", num_runs)

assert request.log.info.calls == [
pretend.call("Running backfill on %d Files." % num_runs)
]

assert len(enqueue_recorder.delay.calls) == num_runs
assert request.find_service_factory.calls == [
pretend.call(tasks.IMalwareCheckService)
]

out, err = capfd.readouterr()
num_output_lines = 0
for file_id in ids:
logged_output = "ExampleHookedCheck:%s %s\n" % (
file_id,
{"manually_triggered": True},
)
num_output_lines += 1 if logged_output in out else 0

assert num_output_lines == num_runs


class TestSyncChecks:
Expand Down
4 changes: 2 additions & 2 deletions warehouse/admin/views/checks.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
from sqlalchemy.orm.exc import NoResultFound

from warehouse.malware.models import MalwareCheck, MalwareCheckState, MalwareCheckType
from warehouse.malware.tasks import backfill, remove_verdicts, run_check
from warehouse.malware.tasks import backfill, remove_verdicts, run_scheduled_check

EVALUATION_RUN_SIZE = 10000

Expand Down Expand Up @@ -94,7 +94,7 @@ def run_evaluation(request):

else:
request.session.flash(f"Running {check.name} now!", queue="success")
request.task(run_check).delay(check.name, manually_triggered=True)
request.task(run_scheduled_check).delay(check.name, manually_triggered=True)

return HTTPSeeOther(
request.route_path("admin.checks.detail", check_name=check.name)
Expand Down
11 changes: 6 additions & 5 deletions warehouse/malware/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,10 @@
import warehouse.malware.checks as checks

from warehouse import db
from warehouse.malware import utils
from warehouse.malware.interfaces import IMalwareCheckService
from warehouse.malware.tasks import run_check
from warehouse.malware.models import MalwareCheckObjectType
from warehouse.malware.tasks import run_scheduled_check
from warehouse.malware.utils import get_enabled_hooked_checks


@db.listens_for(db.Session, "after_flush")
Expand All @@ -31,13 +32,13 @@ def determine_malware_checks(config, session, flush_context):
[
obj.__class__.__name__
for obj in session.new
if obj.__class__.__name__ in utils.valid_check_types()
if obj.__class__.__name__ in MalwareCheckObjectType.__members__
]
):
return

malware_checks = session.info.setdefault("warehouse.malware.checks", set())
enabled_checks = utils.get_enabled_hooked_checks(session)
enabled_checks = get_enabled_hooked_checks(session)
for obj in session.new:
for check_name in enabled_checks.get(obj.__class__.__name__, []):
malware_checks.update([f"{check_name}:{obj.id}"])
Expand Down Expand Up @@ -71,5 +72,5 @@ def includeme(config):
check = check_obj[1]
if check.check_type == "scheduled":
config.add_periodic_task(
crontab(**check.schedule), run_check, args=(check_obj[0],)
crontab(**check.schedule), run_scheduled_check, args=(check_obj[0],)
)
Loading