Skip to content

Commit

Permalink
Refactor MalwareCheckBase. Fixes #7091. (#7196)
Browse files Browse the repository at this point in the history
* Refactor MalwareCheckBase. Fixes #7091.

Add Foreign Keys in MalwareVerdicts for other types of objects
(Releases, Projects).

* Change verdict dict to kwargs.
  • Loading branch information
xmunoz authored and ewdurbin committed Jan 17, 2020
1 parent 047f5ac commit fb84c31
Show file tree
Hide file tree
Showing 4 changed files with 28 additions and 20 deletions.
24 changes: 15 additions & 9 deletions warehouse/malware/checks/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,18 +10,29 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from warehouse.malware.models import MalwareCheck, MalwareCheckState
from warehouse.malware.models import MalwareCheck, MalwareCheckState, MalwareVerdict


class MalwareCheckBase:
def __init__(self, db):
self.db = db
self._name = self.__class__.__name__
self._load_check()
self._load_check_id()
self._verdicts = []

def add_verdict(self, **kwargs):
self._verdicts.append(MalwareVerdict(check_id=self.id, **kwargs))

def run(self, obj_id):
"""
Executes the check.
Runs the check and inserts returned verdicts.
"""
self.scan(obj_id)
self.db.add_all(self._verdicts)

def scan(self, obj_id):
"""
Scans the object and returns a verdict.
"""

def backfill(self, sample=1):
Expand All @@ -31,12 +42,7 @@ def backfill(self, sample=1):
backfill on the entire corpus.
"""

def update(self):
"""
Update the check definition in the database.
"""

def _load_check(self):
def _load_check_id(self):
self.id = (
self.db.query(MalwareCheck.id)
.filter(MalwareCheck.name == self._name)
Expand Down
12 changes: 3 additions & 9 deletions warehouse/malware/checks/example.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,7 @@
# limitations under the License.

from warehouse.malware.checks.base import MalwareCheckBase
from warehouse.malware.models import (
MalwareVerdict,
VerdictClassification,
VerdictConfidence,
)
from warehouse.malware.models import VerdictClassification, VerdictConfidence


class ExampleCheck(MalwareCheckBase):
Expand All @@ -30,12 +26,10 @@ class ExampleCheck(MalwareCheckBase):
def __init__(self, db):
super().__init__(db)

def run(self, file_id):
verdict = MalwareVerdict(
check_id=self.id,
def scan(self, file_id):
self.add_verdict(
file_id=file_id,
classification=VerdictClassification.benign,
confidence=VerdictConfidence.High,
message="Nothing to see here!",
)
self.db.add(verdict)
6 changes: 5 additions & 1 deletion warehouse/malware/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,9 @@ class MalwareVerdict(db.Model):
nullable=False,
index=True,
)
file_id = Column(ForeignKey("release_files.id"), nullable=False)
file_id = Column(ForeignKey("release_files.id"), nullable=True)
release_id = Column(ForeignKey("releases.id"), nullable=True)
project_id = Column(ForeignKey("projects.id"), nullable=True)
classification = Column(
Enum(VerdictClassification, values_callable=lambda x: [e.value for e in x]),
nullable=False,
Expand All @@ -135,3 +137,5 @@ class MalwareVerdict(db.Model):

check = orm.relationship("MalwareCheck", foreign_keys=[check_id], lazy=True)
release_file = orm.relationship("File", foreign_keys=[file_id], lazy=True)
release = orm.relationship("Release", foreign_keys=[release_id], lazy=True)
project = orm.relationship("Project", foreign_keys=[project_id], lazy=True)
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,9 @@ def upgrade():
"run_date", sa.DateTime(), server_default=sa.text("now()"), nullable=False
),
sa.Column("check_id", postgresql.UUID(as_uuid=True), nullable=False),
sa.Column("file_id", postgresql.UUID(as_uuid=True), nullable=False),
sa.Column("file_id", postgresql.UUID(as_uuid=True), nullable=True),
sa.Column("project_id", postgresql.UUID(as_uuid=True), nullable=True),
sa.Column("release_id", postgresql.UUID(as_uuid=True), nullable=True),
sa.Column("classification", VerdictClassifications, nullable=False,),
sa.Column("confidence", VerdictConfidences, nullable=False,),
sa.Column("message", sa.Text(), nullable=True),
Expand All @@ -94,6 +96,8 @@ def upgrade():
["check_id"], ["malware_checks.id"], onupdate="CASCADE", ondelete="CASCADE"
),
sa.ForeignKeyConstraint(["file_id"], ["release_files.id"]),
sa.ForeignKeyConstraint(["release_id"], ["releases.id"]),
sa.ForeignKeyConstraint(["project_id"], ["projects.id"]),
sa.PrimaryKeyConstraint("id"),
)
op.create_index(
Expand Down

0 comments on commit fb84c31

Please sign in to comment.