From fb84c31e46f4add6f10e4de52e4db5de2e663a85 Mon Sep 17 00:00:00 2001 From: Cristina Date: Tue, 7 Jan 2020 12:07:32 -0800 Subject: [PATCH] Refactor MalwareCheckBase. Fixes #7091. (#7196) * Refactor MalwareCheckBase. Fixes #7091. Add Foreign Keys in MalwareVerdicts for other types of objects (Releases, Projects). * Change verdict dict to kwargs. --- warehouse/malware/checks/base.py | 24 ++++++++++++------- warehouse/malware/checks/example.py | 12 +++------- warehouse/malware/models.py | 6 ++++- ...1ff3d24c22_add_malware_detection_tables.py | 6 ++++- 4 files changed, 28 insertions(+), 20 deletions(-) diff --git a/warehouse/malware/checks/base.py b/warehouse/malware/checks/base.py index b5102bfe6f71..6c94937abbde 100644 --- a/warehouse/malware/checks/base.py +++ b/warehouse/malware/checks/base.py @@ -10,18 +10,29 @@ # See the License for the specific language governing permissions and # limitations under the License. -from warehouse.malware.models import MalwareCheck, MalwareCheckState +from warehouse.malware.models import MalwareCheck, MalwareCheckState, MalwareVerdict class MalwareCheckBase: def __init__(self, db): self.db = db self._name = self.__class__.__name__ - self._load_check() + self._load_check_id() + self._verdicts = [] + + def add_verdict(self, **kwargs): + self._verdicts.append(MalwareVerdict(check_id=self.id, **kwargs)) def run(self, obj_id): """ - Executes the check. + Runs the check and inserts returned verdicts. + """ + self.scan(obj_id) + self.db.add_all(self._verdicts) + + def scan(self, obj_id): + """ + Scans the object and returns a verdict. """ def backfill(self, sample=1): @@ -31,12 +42,7 @@ def backfill(self, sample=1): backfill on the entire corpus. """ - def update(self): - """ - Update the check definition in the database. - """ - - def _load_check(self): + def _load_check_id(self): self.id = ( self.db.query(MalwareCheck.id) .filter(MalwareCheck.name == self._name) diff --git a/warehouse/malware/checks/example.py b/warehouse/malware/checks/example.py index 519edecfd4df..22b91906ffa3 100644 --- a/warehouse/malware/checks/example.py +++ b/warehouse/malware/checks/example.py @@ -11,11 +11,7 @@ # limitations under the License. from warehouse.malware.checks.base import MalwareCheckBase -from warehouse.malware.models import ( - MalwareVerdict, - VerdictClassification, - VerdictConfidence, -) +from warehouse.malware.models import VerdictClassification, VerdictConfidence class ExampleCheck(MalwareCheckBase): @@ -30,12 +26,10 @@ class ExampleCheck(MalwareCheckBase): def __init__(self, db): super().__init__(db) - def run(self, file_id): - verdict = MalwareVerdict( - check_id=self.id, + def scan(self, file_id): + self.add_verdict( file_id=file_id, classification=VerdictClassification.benign, confidence=VerdictConfidence.High, message="Nothing to see here!", ) - self.db.add(verdict) diff --git a/warehouse/malware/models.py b/warehouse/malware/models.py index 257e7bfa2bd5..3e9aa388a701 100644 --- a/warehouse/malware/models.py +++ b/warehouse/malware/models.py @@ -115,7 +115,9 @@ class MalwareVerdict(db.Model): nullable=False, index=True, ) - file_id = Column(ForeignKey("release_files.id"), nullable=False) + file_id = Column(ForeignKey("release_files.id"), nullable=True) + release_id = Column(ForeignKey("releases.id"), nullable=True) + project_id = Column(ForeignKey("projects.id"), nullable=True) classification = Column( Enum(VerdictClassification, values_callable=lambda x: [e.value for e in x]), nullable=False, @@ -135,3 +137,5 @@ class MalwareVerdict(db.Model): check = orm.relationship("MalwareCheck", foreign_keys=[check_id], lazy=True) release_file = orm.relationship("File", foreign_keys=[file_id], lazy=True) + release = orm.relationship("Release", foreign_keys=[release_id], lazy=True) + project = orm.relationship("Project", foreign_keys=[project_id], lazy=True) diff --git a/warehouse/migrations/versions/061ff3d24c22_add_malware_detection_tables.py b/warehouse/migrations/versions/061ff3d24c22_add_malware_detection_tables.py index e74a9ddabe94..6e23aeb243f8 100644 --- a/warehouse/migrations/versions/061ff3d24c22_add_malware_detection_tables.py +++ b/warehouse/migrations/versions/061ff3d24c22_add_malware_detection_tables.py @@ -77,7 +77,9 @@ def upgrade(): "run_date", sa.DateTime(), server_default=sa.text("now()"), nullable=False ), sa.Column("check_id", postgresql.UUID(as_uuid=True), nullable=False), - sa.Column("file_id", postgresql.UUID(as_uuid=True), nullable=False), + sa.Column("file_id", postgresql.UUID(as_uuid=True), nullable=True), + sa.Column("project_id", postgresql.UUID(as_uuid=True), nullable=True), + sa.Column("release_id", postgresql.UUID(as_uuid=True), nullable=True), sa.Column("classification", VerdictClassifications, nullable=False,), sa.Column("confidence", VerdictConfidences, nullable=False,), sa.Column("message", sa.Text(), nullable=True), @@ -94,6 +96,8 @@ def upgrade(): ["check_id"], ["malware_checks.id"], onupdate="CASCADE", ondelete="CASCADE" ), sa.ForeignKeyConstraint(["file_id"], ["release_files.id"]), + sa.ForeignKeyConstraint(["release_id"], ["releases.id"]), + sa.ForeignKeyConstraint(["project_id"], ["projects.id"]), sa.PrimaryKeyConstraint("id"), ) op.create_index(