Skip to content

Commit

Permalink
Refactor testing logic #7098 (#7257)
Browse files Browse the repository at this point in the history
- Add `schedule` field to MalwareCheck model #7096
- Move ExampleCheck into tests/common/ to remove test dependency from
prod code
- Rename functions and classes to differentiate between "hooked" and
"scheduled" checks
  • Loading branch information
xmunoz authored and ewdurbin committed Jan 27, 2020
1 parent 148fcea commit 7616c91
Show file tree
Hide file tree
Showing 13 changed files with 244 additions and 119 deletions.
14 changes: 14 additions & 0 deletions tests/common/checks/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from .hooked import ExampleHookedCheck # noqa
from .scheduled import ExampleScheduledCheck # noqa
Original file line number Diff line number Diff line change
Expand Up @@ -14,19 +14,19 @@
from warehouse.malware.models import VerdictClassification, VerdictConfidence


class ExampleCheck(MalwareCheckBase):
class ExampleHookedCheck(MalwareCheckBase):

version = 1
short_description = "An example hook-based check"
long_description = """The purpose of this check is to demonstrate the \
implementation of a hook-based check. This check will generate verdicts if enabled."""
long_description = "The purpose of this check is to test the \
implementation of a hook-based check. This check will generate verdicts if enabled."
check_type = "event_hook"
hooked_object = "File"

def __init__(self, db):
super().__init__(db)

def scan(self, file_id):
def scan(self, file_id=None):
self.add_verdict(
file_id=file_id,
classification=VerdictClassification.benign,
Expand Down
37 changes: 37 additions & 0 deletions tests/common/checks/scheduled.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from warehouse.malware.checks.base import MalwareCheckBase
from warehouse.malware.models import VerdictClassification, VerdictConfidence
from warehouse.packaging.models import Project


class ExampleScheduledCheck(MalwareCheckBase):

version = 1
short_description = "An example scheduled check"
long_description = "The purpose of this check is to test the \
implementation of a scheduled check. This check will generate verdicts if enabled."
check_type = "scheduled"
schedule = {"minute": "0", "hour": "*/8"}

def __init__(self, db):
super().__init__(db)

def scan(self):
project = self.db.query(Project).first()
self.add_verdict(
project_id=project.id,
classification=VerdictClassification.benign,
confidence=VerdictConfidence.High,
message="Nothing to see here!",
)
1 change: 1 addition & 0 deletions tests/common/db/malware.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ class Meta:
long_description = factory.fuzzy.FuzzyText(length=300)
check_type = factory.fuzzy.FuzzyChoice(list(MalwareCheckType))
hooked_object = factory.fuzzy.FuzzyChoice(list(MalwareCheckObjectType))
schedule = {"minute": "*/10"}
state = factory.fuzzy.FuzzyChoice(list(MalwareCheckState))
created = factory.fuzzy.FuzzyNaiveDateTime(
datetime.datetime.utcnow() - datetime.timedelta(days=7)
Expand Down
19 changes: 14 additions & 5 deletions tests/unit/malware/test_checks.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,26 +12,35 @@

import inspect

import warehouse.malware.checks as checks
import pytest

import warehouse.malware.checks as prod_checks

from warehouse.malware.checks.base import MalwareCheckBase
from warehouse.malware.utils import get_check_fields

from ...common import checks as test_checks


def test_checks_subclass_base():
checks_from_module = inspect.getmembers(checks, inspect.isclass)
prod_checks_from_module = inspect.getmembers(prod_checks, inspect.isclass)
test_checks_from_module = inspect.getmembers(test_checks, inspect.isclass)
all_checks = prod_checks_from_module + test_checks_from_module

subclasses_of_malware_base = {
cls.__name__: cls for cls in MalwareCheckBase.__subclasses__()
}

assert len(checks_from_module) == len(subclasses_of_malware_base)
assert len(all_checks) == len(subclasses_of_malware_base)

for check_name, check in checks_from_module:
for check_name, check in all_checks:
assert subclasses_of_malware_base[check_name] == check


def test_checks_fields():
@pytest.mark.parametrize(
("checks"), [prod_checks, test_checks],
)
def test_checks_fields(checks):
checks_from_module = inspect.getmembers(checks, inspect.isclass)

for check_name, check in checks_from_module:
Expand Down
25 changes: 14 additions & 11 deletions tests/unit/malware/test_init.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,15 +18,16 @@
from warehouse.malware import utils
from warehouse.malware.interfaces import IMalwareCheckService

from ...common import checks as test_checks
from ...common.db.accounts import UserFactory
from ...common.db.packaging import FileFactory, ProjectFactory, ReleaseFactory


def test_determine_malware_checks_no_checks(monkeypatch, db_request):
def get_enabled_checks(session):
def get_enabled_hooked_checks(session):
return defaultdict(list)

monkeypatch.setattr(utils, "get_enabled_checks", get_enabled_checks)
monkeypatch.setattr(utils, "get_enabled_hooked_checks", get_enabled_hooked_checks)

project = ProjectFactory.create(name="foo")
release = ReleaseFactory.create(project=project)
Expand All @@ -39,13 +40,13 @@ def get_enabled_checks(session):


def test_determine_malware_checks_nothing_new(monkeypatch, db_request):
def get_enabled_checks(session):
def get_enabled_hooked_checks(session):
result = defaultdict(list)
result["File"] = ["Check1", "Check2"]
result["Release"] = ["Check3"]
return result

monkeypatch.setattr(utils, "get_enabled_checks", get_enabled_checks)
monkeypatch.setattr(utils, "get_enabled_hooked_checks", get_enabled_hooked_checks)

project = ProjectFactory.create(name="foo")
release = ReleaseFactory.create(project=project)
Expand All @@ -58,13 +59,13 @@ def get_enabled_checks(session):


def test_determine_malware_checks_unsupported_object(monkeypatch, db_request):
def get_enabled_checks(session):
def get_enabled_hooked_checks(session):
result = defaultdict(list)
result["File"] = ["Check1", "Check2"]
result["Release"] = ["Check3"]
return result

monkeypatch.setattr(utils, "get_enabled_checks", get_enabled_checks)
monkeypatch.setattr(utils, "get_enabled_hooked_checks", get_enabled_hooked_checks)

user = UserFactory.create()

Expand All @@ -75,13 +76,13 @@ def get_enabled_checks(session):


def test_determine_malware_checks_file_only(monkeypatch, db_request):
def get_enabled_checks(session):
def get_enabled_hooked_checks(session):
result = defaultdict(list)
result["File"] = ["Check1", "Check2"]
result["Release"] = ["Check3"]
return result

monkeypatch.setattr(utils, "get_enabled_checks", get_enabled_checks)
monkeypatch.setattr(utils, "get_enabled_hooked_checks", get_enabled_hooked_checks)

project = ProjectFactory.create(name="foo")
release = ReleaseFactory.create(project=project)
Expand All @@ -95,13 +96,13 @@ def get_enabled_checks(session):


def test_determine_malware_checks_file_and_release(monkeypatch, db_request):
def get_enabled_checks(session):
def get_enabled_hooked_checks(session):
result = defaultdict(list)
result["File"] = ["Check1", "Check2"]
result["Release"] = ["Check3"]
return result

monkeypatch.setattr(utils, "get_enabled_checks", get_enabled_checks)
monkeypatch.setattr(utils, "get_enabled_hooked_checks", get_enabled_hooked_checks)

project = ProjectFactory.create(name="foo")
release = ReleaseFactory.create(project=project)
Expand Down Expand Up @@ -149,7 +150,9 @@ def test_enqueue_malware_checks_no_checks(app_config):
assert "warehouse.malware.checks" not in session.info


def test_includeme():
def test_includeme(monkeypatch):
monkeypatch.setattr(malware, "checks", test_checks)

malware_check_class = pretend.stub(
create_service=pretend.call_recorder(lambda *a, **kw: pretend.stub())
)
Expand Down
Loading

0 comments on commit 7616c91

Please sign in to comment.