From aca136d73a092be519579bde692e7372b2c44ae5 Mon Sep 17 00:00:00 2001 From: Cristina Date: Wed, 8 Jan 2020 07:37:48 -0800 Subject: [PATCH] Add wipe-out functionality (#7202) * Add wipe-out functionality Related: #7133 * Call list explicitly --- tests/common/db/malware.py | 21 ++++++++-- tests/unit/admin/views/test_checks.py | 22 +++++++++-- tests/unit/malware/test_tasks.py | 56 ++++++++++++++++++++++++++- warehouse/admin/views/checks.py | 3 ++ warehouse/malware/tasks.py | 22 ++++++++++- 5 files changed, 114 insertions(+), 10 deletions(-) diff --git a/tests/common/db/malware.py b/tests/common/db/malware.py index 7b01dc4723d6..8c365f4abb66 100644 --- a/tests/common/db/malware.py +++ b/tests/common/db/malware.py @@ -20,9 +20,13 @@ MalwareCheckObjectType, MalwareCheckState, MalwareCheckType, + MalwareVerdict, + VerdictClassification, + VerdictConfidence, ) from .base import WarehouseFactory +from .packaging import FileFactory class MalwareCheckFactory(WarehouseFactory): @@ -33,9 +37,20 @@ class Meta: version = 1 short_description = factory.fuzzy.FuzzyText(length=80) long_description = factory.fuzzy.FuzzyText(length=300) - check_type = factory.fuzzy.FuzzyChoice([e for e in MalwareCheckType]) - hooked_object = factory.fuzzy.FuzzyChoice([e for e in MalwareCheckObjectType]) - state = factory.fuzzy.FuzzyChoice([e for e in MalwareCheckState]) + check_type = factory.fuzzy.FuzzyChoice(list(MalwareCheckType)) + hooked_object = factory.fuzzy.FuzzyChoice(list(MalwareCheckObjectType)) + state = factory.fuzzy.FuzzyChoice(list(MalwareCheckState)) created = factory.fuzzy.FuzzyNaiveDateTime( datetime.datetime.utcnow() - datetime.timedelta(days=7) ) + + +class MalwareVerdictFactory(WarehouseFactory): + class Meta: + model = MalwareVerdict + + check = factory.SubFactory(MalwareCheckFactory) + release_file = factory.SubFactory(FileFactory) + classification = factory.fuzzy.FuzzyChoice(list(VerdictClassification)) + confidence = factory.fuzzy.FuzzyChoice(list(VerdictConfidence)) + message = factory.fuzzy.FuzzyText(length=80) diff --git a/tests/unit/admin/views/test_checks.py b/tests/unit/admin/views/test_checks.py index 55a45b2d5ac0..601e26e79858 100644 --- a/tests/unit/admin/views/test_checks.py +++ b/tests/unit/admin/views/test_checks.py @@ -67,17 +67,25 @@ def test_get_check_not_found(self, db_request): class TestChangeCheckState: - def test_change_to_enabled(self, db_request): + @pytest.mark.parametrize( + ("final_state"), [MalwareCheckState.disabled, MalwareCheckState.wiped_out] + ) + def test_change_to_valid_state(self, db_request, final_state): check = MalwareCheckFactory.create( name="MyCheck", state=MalwareCheckState.disabled ) - db_request.POST = {"id": check.id, "check_state": "enabled"} + db_request.POST = {"id": check.id, "check_state": final_state.value} db_request.matchdict["check_name"] = check.name db_request.session = pretend.stub( flash=pretend.call_recorder(lambda *a, **kw: None) ) + wipe_out_recorder = pretend.stub( + delay=pretend.call_recorder(lambda *a, **kw: None) + ) + db_request.task = pretend.call_recorder(lambda *a, **kw: wipe_out_recorder) + db_request.route_path = pretend.call_recorder( lambda *a, **kw: "/admin/checks/MyCheck/change_state" ) @@ -85,9 +93,15 @@ def test_change_to_enabled(self, db_request): views.change_check_state(db_request) assert db_request.session.flash.calls == [ - pretend.call("Changed 'MyCheck' check to 'enabled'!", queue="success") + pretend.call( + "Changed 'MyCheck' check to '%s'!" % final_state.value, queue="success" + ) ] - assert check.state == MalwareCheckState.enabled + + assert check.state == final_state + + if final_state == MalwareCheckState.wiped_out: + assert wipe_out_recorder.delay.calls == [pretend.call("MyCheck")] def test_change_to_invalid_state(self, db_request): check = MalwareCheckFactory.create(name="MyCheck") diff --git a/tests/unit/malware/test_tasks.py b/tests/unit/malware/test_tasks.py index 1057af6855a5..5c89cc5fd562 100644 --- a/tests/unit/malware/test_tasks.py +++ b/tests/unit/malware/test_tasks.py @@ -19,9 +19,9 @@ import warehouse.malware.checks as checks from warehouse.malware.models import MalwareCheck, MalwareCheckState, MalwareVerdict -from warehouse.malware.tasks import run_check, sync_checks +from warehouse.malware.tasks import remove_verdicts, run_check, sync_checks -from ...common.db.malware import MalwareCheckFactory +from ...common.db.malware import MalwareCheckFactory, MalwareVerdictFactory from ...common.db.packaging import FileFactory, ProjectFactory, ReleaseFactory @@ -255,3 +255,55 @@ def test_only_wiped_out(self, db_session): from codebase." ), ] + + +class TestRemoveVerdicts: + def test_no_verdicts(self, db_session): + check = MalwareCheckFactory.create() + + request = pretend.stub( + db=db_session, + log=pretend.stub(info=pretend.call_recorder(lambda *args, **kwargs: None),), + ) + task = pretend.stub() + remove_verdicts(task, request, check.name) + + assert request.log.info.calls == [ + pretend.call( + "Removing 0 malware verdicts associated with %s version 1." % check.name + ), + ] + + @pytest.mark.parametrize(("check_with_verdicts"), [True, False]) + def test_many_verdicts(self, db_session, check_with_verdicts): + check0 = MalwareCheckFactory.create() + check1 = MalwareCheckFactory.create() + project = ProjectFactory.create(name="foo") + release = ReleaseFactory.create(project=project) + file0 = FileFactory.create(release=release, filename="foo.bar") + num_verdicts = 10 + + for i in range(num_verdicts): + MalwareVerdictFactory.create(check=check1, release_file=file0) + + request = pretend.stub( + db=db_session, + log=pretend.stub(info=pretend.call_recorder(lambda *args, **kwargs: None),), + ) + + task = pretend.stub() + + if check_with_verdicts: + wiped_out_check = check1 + else: + wiped_out_check = check0 + num_verdicts = 0 + + remove_verdicts(task, request, wiped_out_check.name) + + assert request.log.info.calls == [ + pretend.call( + "Removing %d malware verdicts associated with %s version 1." + % (num_verdicts, wiped_out_check.name) + ), + ] diff --git a/warehouse/admin/views/checks.py b/warehouse/admin/views/checks.py index e3d38a88a0c0..f1ad86685d94 100644 --- a/warehouse/admin/views/checks.py +++ b/warehouse/admin/views/checks.py @@ -15,6 +15,7 @@ from sqlalchemy.orm.exc import NoResultFound from warehouse.malware.models import MalwareCheck, MalwareCheckState +from warehouse.malware.tasks import remove_verdicts @view_config( @@ -80,6 +81,8 @@ def change_check_state(request): except (AttributeError, KeyError): request.session.flash("Invalid check state provided.", queue="error") else: + if check.state == MalwareCheckState.wiped_out: + request.task(remove_verdicts).delay(check.name) request.session.flash( f"Changed {check.name!r} check to {check.state.value!r}!", queue="success" ) diff --git a/warehouse/malware/tasks.py b/warehouse/malware/tasks.py index 1548d28e66d7..0d2f570c436f 100644 --- a/warehouse/malware/tasks.py +++ b/warehouse/malware/tasks.py @@ -14,7 +14,7 @@ import warehouse.malware.checks as checks -from warehouse.malware.models import MalwareCheck, MalwareCheckState +from warehouse.malware.models import MalwareCheck, MalwareCheckState, MalwareVerdict from warehouse.malware.utils import get_check_fields from warehouse.tasks import task @@ -86,3 +86,23 @@ def sync_checks(task, request): request.log.info("Adding new %s to the database." % check_name) fields = get_check_fields(check) request.db.add(MalwareCheck(**fields)) + + +@task(bind=True, ignore_result=True, acks_late=True) +def remove_verdicts(task, request, check_name): + check_ids = ( + request.db.query(MalwareCheck.id, MalwareCheck.version) + .filter(MalwareCheck.name == check_name) + .all() + ) + + for check_id, check_version in check_ids: + query = request.db.query(MalwareVerdict).filter( + MalwareVerdict.check_id == check_id + ) + num_verdicts = query.count() + request.log.info( + "Removing %d malware verdicts associated with %s version %d." + % (num_verdicts, check_name, check_version) + ) + query.delete(synchronize_session=False)