Skip to content

Commit

Permalink
fix: moved logic to more suitable classes and files
Browse files Browse the repository at this point in the history
  • Loading branch information
mickol34 committed Oct 14, 2024
1 parent bab9068 commit 4074fe3
Show file tree
Hide file tree
Showing 4 changed files with 39 additions and 43 deletions.
14 changes: 14 additions & 0 deletions src/db.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
from .models.job import Job
from .models.jobagent import JobAgent
from .models.match import Match
from .models.queryresult import QueryResult
from .schema import MatchesSchema, ConfigSchema
from .config import app_config

Expand Down Expand Up @@ -109,6 +110,19 @@ def add_match(self, job: JobId, match: Match) -> None:
session.add(match)
session.commit()

def add_queryresult(self, job_id: int | None, files: List[str]) -> None:
with self.session() as session:
obj = QueryResult(job_id=job_id, files=files)
session.add(obj)
session.commit()

def remove_queryresult(self, job_id: int | None) -> None:
with self.session() as session:
session.query(QueryResult).where(
QueryResult.job_id == job_id
).delete()
session.commit()

def job_contains(self, job: JobId, ordinal: int, file_path: str) -> bool:
"""Make sure that the file path is in the job results."""
with self.session() as session:
Expand Down
12 changes: 1 addition & 11 deletions src/lib/ursadb.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,6 @@
import zmq # type: ignore
from typing import Dict, Any, List, Optional

from config import app_config
from models.queryresult import QueryResult
from db import Database, JobId


Json = Dict[str, Any]

Expand Down Expand Up @@ -41,7 +37,6 @@ def __str__(self) -> str:
class UrsaDb:
def __init__(self, backend: str) -> None:
self.backend = backend
self.redis_db = Database(app_config.redis.host, app_config.redis.port)

def __execute(self, command: str, recv_timeout: int = 2000) -> Json:
context = zmq.Context()
Expand All @@ -58,7 +53,6 @@ def __execute(self, command: str, recv_timeout: int = 2000) -> Json:
def query(
self,
query: str,
job_id: JobId,
taints: List[str] | None = None,
dataset: Optional[str] = None,
) -> Json:
Expand All @@ -79,13 +73,9 @@ def query(
error = res.get("error", {}).get("message", "(no message)")
return {"error": f"ursadb failed: {error}"}

with self.redis_db.session() as session:
obj = QueryResult(job_id=job_id, files=res['result']['files'])
session.add(obj)
session.commit()

return {
"time": (end - start),
"files": res["result"]["files"],
}

def pop(self, iterator: str, count: int) -> PopResult:
Expand Down
7 changes: 4 additions & 3 deletions src/models/queryresult.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
from sqlmodel import Field, SQLModel, ARRAY, Column, String
from typing import List
from typing import List, Union


class QueryResult(SQLModel, table=True):
job_id: str = Field(foreign_key="job.internal_id", primary_key=True)
files: List[str] = Field(sa_column=Column(ARRAY(String)))
id: Union[int, None] = Field(default=None, primary_key=True)
job_id: Union[int, None] = Field(foreign_key="job.internal_id")
files: List[str] = Field(sa_column=Column(ARRAY(String)))
49 changes: 20 additions & 29 deletions src/tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,17 +3,15 @@
from rq import get_current_job, Queue # type: ignore
from redis import Redis
from contextlib import contextmanager
from sqlalchemy import delete, update
from sqlmodel import select
import yara # type: ignore
from itertools import accumulate

from .db import Database, JobId
from .util import make_sha256_tag
from .config import app_config
from .plugins import PluginManager
from .models.job import Job
from .models.match import Match
from .models.queryresult import QueryResult
from .lib.yaraparse import parse_yara, combine_rules
from .lib.ursadb import Json, UrsaDb
from .metadata import Metadata
Expand Down Expand Up @@ -239,13 +237,14 @@ def query_ursadb(job_id: JobId, dataset_id: str, ursadb_query: str) -> None:
logging.info("Job was cancelled, returning...")
return

result = agent.ursa.query(ursadb_query, job_id, job.taints, dataset_id)
result = agent.ursa.query(ursadb_query, job.taints, dataset_id)
if "error" in result:
raise RuntimeError(result["error"])

with agent.db.session() as session:
result = session.exec(select(QueryResult).where(QueryResult.job_id == job_id)).one()
file_count = len(result.files)
files = result["files"]
agent.db.add_queryresult(job.internal_id, files)

file_count = len(files)

total_files = agent.db.update_job_files(job_id, file_count)
if job.files_limit and total_files > job.files_limit:
Expand All @@ -254,44 +253,36 @@ def query_ursadb(job_id: JobId, dataset_id: str, ursadb_query: str) -> None:
"Try a more precise query."
)

batches = __get_batch_sizes(file_count)
# add len(batches) new tasks, -1 to account for this task
agent.add_tasks_in_progress(job, len(batches) - 1)
batch_sizes = __get_batch_sizes(file_count)
# add len(batch_sizes) new tasks, -1 to account for this task
agent.add_tasks_in_progress(job, len(batch_sizes) - 1)

for batch in batches:
batched_files = (
files[batch_end - batch_size : batch_end]
for batch_end, batch_size in zip(
accumulate(batch_sizes), batch_sizes
)
)

for batch_files in batched_files:
agent.queue.enqueue(
run_yara_batch,
job_id,
result,
batch,
batch_files,
job_timeout=app_config.rq.job_timeout,
)

agent.db.dataset_query_done(job_id)
agent.db.remove_queryresult(job.internal_id)


def run_yara_batch(job_id: JobId, result: QueryResult, batch_size: int) -> None:
def run_yara_batch(job_id: JobId, batch_files: List[str]) -> None:
"""Actually scans files, and updates a database with the results."""
with job_context(job_id) as agent:
job = agent.db.get_job(job_id)
if job.status == "cancelled":
logging.info("Job was cancelled, returning...")
return

## 1. get batch_size first files from result
batch_files = result.files[0:batch_size]

## 2. remove batch files from result
with agent.db.session() as session:
session.execute(
update(QueryResult).where(QueryResult.job_id == result.job_id).values(files=result.files[batch_size+1:])
)

## 3. if result has no files, delete
session.execute(
delete(QueryResult).where(QueryResult.job_id == job_id).where(QueryResult.files == [])
)
session.commit()

agent.execute_yara(job, batch_files)
agent.add_tasks_in_progress(job, -1)

0 comments on commit 4074fe3

Please sign in to comment.