diff --git a/src/db.py b/src/db.py
index 47036d0a..b7218bf8 100644
--- a/src/db.py
+++ b/src/db.py
@@ -22,6 +22,7 @@
 from .models.job import Job
 from .models.jobagent import JobAgent
 from .models.match import Match
+from .models.queryresult import QueryResult
 from .schema import MatchesSchema, ConfigSchema
 from .config import app_config
 
@@ -109,6 +110,19 @@ def add_match(self, job: JobId, match: Match) -> None:
         session.add(match)
         session.commit()
 
+    def add_queryresult(self, job_id: int | None, files: List[str]) -> None:
+        with self.session() as session:
+            obj = QueryResult(job_id=job_id, files=files)
+            session.add(obj)
+            session.commit()
+
+    def remove_queryresult(self, job_id: int | None) -> None:
+        with self.session() as session:
+            session.query(QueryResult).where(
+                QueryResult.job_id == job_id
+            ).delete()
+            session.commit()
+
     def job_contains(self, job: JobId, ordinal: int, file_path: str) -> bool:
         """Make sure that the file path is in the job results."""
         with self.session() as session:
diff --git a/src/lib/ursadb.py b/src/lib/ursadb.py
index 4c537bf9..2b0b4fbf 100644
--- a/src/lib/ursadb.py
+++ b/src/lib/ursadb.py
@@ -3,10 +3,6 @@
 import zmq  # type: ignore
 from typing import Dict, Any, List, Optional
 
-from config import app_config
-from models.queryresult import QueryResult
-from db import Database, JobId
-
 Json = Dict[str, Any]
 
 
@@ -41,7 +37,6 @@ def __str__(self) -> str:
 class UrsaDb:
     def __init__(self, backend: str) -> None:
         self.backend = backend
-        self.redis_db = Database(app_config.redis.host, app_config.redis.port)
 
     def __execute(self, command: str, recv_timeout: int = 2000) -> Json:
         context = zmq.Context()
@@ -58,7 +53,6 @@ def __execute(self, command: str, recv_timeout: int = 2000) -> Json:
     def query(
         self,
         query: str,
-        job_id: JobId,
         taints: List[str] | None = None,
         dataset: Optional[str] = None,
     ) -> Json:
@@ -79,13 +73,9 @@ def query(
             error = res.get("error", {}).get("message", "(no message)")
             return {"error": f"ursadb failed: {error}"}
 
-        with self.redis_db.session() as session:
-            obj = QueryResult(job_id=job_id, files=res['result']['files'])
-            session.add(obj)
-            session.commit()
-
         return {
             "time": (end - start),
+            "files": res["result"]["files"],
         }
 
     def pop(self, iterator: str, count: int) -> PopResult:
diff --git a/src/models/queryresult.py b/src/models/queryresult.py
index 60f90f21..76b94438 100644
--- a/src/models/queryresult.py
+++ b/src/models/queryresult.py
@@ -1,7 +1,8 @@
 from sqlmodel import Field, SQLModel, ARRAY, Column, String
-from typing import List
+from typing import List, Union
 
 
 class QueryResult(SQLModel, table=True):
-    job_id: str = Field(foreign_key="job.internal_id", primary_key=True)
-    files: List[str] = Field(sa_column=Column(ARRAY(String)))
+    id: Union[int, None] = Field(default=None, primary_key=True)
+    job_id: Union[int, None] = Field(foreign_key="job.internal_id")
+    files: List[str] = Field(sa_column=Column(ARRAY(String)))
diff --git a/src/tasks.py b/src/tasks.py
index eea633f2..dd81b426 100644
--- a/src/tasks.py
+++ b/src/tasks.py
@@ -3,9 +3,8 @@
 from rq import get_current_job, Queue  # type: ignore
 from redis import Redis
 from contextlib import contextmanager
-from sqlalchemy import delete, update
-from sqlmodel import select
 import yara  # type: ignore
+from itertools import accumulate
 
 from .db import Database, JobId
 from .util import make_sha256_tag
@@ -13,7 +12,6 @@
 from .plugins import PluginManager
 from .models.job import Job
 from .models.match import Match
-from .models.queryresult import QueryResult
 from .lib.yaraparse import parse_yara, combine_rules
 from .lib.ursadb import Json, UrsaDb
 from .metadata import Metadata
@@ -239,13 +237,14 @@ def query_ursadb(job_id: JobId, dataset_id: str, ursadb_query: str) -> None:
             logging.info("Job was cancelled, returning...")
             return
 
-        result = agent.ursa.query(ursadb_query, job_id, job.taints, dataset_id)
+        result = agent.ursa.query(ursadb_query, job.taints, dataset_id)
         if "error" in result:
             raise RuntimeError(result["error"])
 
-        with agent.db.session() as session:
-            result = session.exec(select(QueryResult).where(QueryResult.job_id == job_id)).one()
-            file_count = len(result.files)
+        files = result["files"]
+        agent.db.add_queryresult(job.internal_id, files)
+
+        file_count = len(files)
 
         total_files = agent.db.update_job_files(job_id, file_count)
         if job.files_limit and total_files > job.files_limit:
@@ -254,23 +253,30 @@ def query_ursadb(job_id: JobId, dataset_id: str, ursadb_query: str) -> None:
             "Try a more precise query."
         )
 
-        batches = __get_batch_sizes(file_count)
-        # add len(batches) new tasks, -1 to account for this task
-        agent.add_tasks_in_progress(job, len(batches) - 1)
+        batch_sizes = __get_batch_sizes(file_count)
+        # add len(batch_sizes) new tasks, -1 to account for this task
+        agent.add_tasks_in_progress(job, len(batch_sizes) - 1)
 
-        for batch in batches:
+        batched_files = (
+            files[batch_end - batch_size : batch_end]
+            for batch_end, batch_size in zip(
+                accumulate(batch_sizes), batch_sizes
+            )
+        )
+
+        for batch_files in batched_files:
             agent.queue.enqueue(
                 run_yara_batch,
                 job_id,
-                result,
-                batch,
+                batch_files,
                 job_timeout=app_config.rq.job_timeout,
             )
 
         agent.db.dataset_query_done(job_id)
+        agent.db.remove_queryresult(job.internal_id)
 
 
-def run_yara_batch(job_id: JobId, result: QueryResult, batch_size: int) -> None:
+def run_yara_batch(job_id: JobId, batch_files: List[str]) -> None:
     """Actually scans files, and updates a database with the results."""
     with job_context(job_id) as agent:
         job = agent.db.get_job(job_id)
@@ -278,20 +284,5 @@ def run_yara_batch(job_id: JobId, result: QueryResult, batch_size: int) -> None:
         logging.info("Job was cancelled, returning...")
         return
 
-        ## 1. get batch_size first files from result
-        batch_files = result.files[0:batch_size]
-
-        ## 2. remove batch files from result
-        with agent.db.session() as session:
-            session.execute(
-                update(QueryResult).where(QueryResult.job_id == result.job_id).values(files=result.files[batch_size+1:])
-            )
-
-            ## 3. if result has no files, delete
-            session.execute(
-                delete(QueryResult).where(QueryResult.job_id == job_id).where(QueryResult.files == [])
-            )
-            session.commit()
-
         agent.execute_yara(job, batch_files)
         agent.add_tasks_in_progress(job, -1)