diff --git a/src/app.py b/src/app.py
index a68e1f0d..0bfdb746 100644
--- a/src/app.py
+++ b/src/app.py
@@ -1,8 +1,4 @@
-import json
 import os
-import random
-import string
-import time
 import uvicorn
 
 from datetime import datetime
@@ -16,7 +12,8 @@
 from lib.ursadb import UrsaDb
 from lib.yaraparse import parse_yara
 
-from util import make_redis, mquery_version
+from util import mquery_version
+from db import Database, JobId
 import config
 
 from typing import Any, Callable, List, Union
@@ -37,9 +34,9 @@
     BackendStatusDatasetsSchema,
 )
 
-redis = make_redis()
+db = Database()
 app = FastAPI()
-db = UrsaDb(config.BACKEND)
+ursa = UrsaDb(config.BACKEND)
 
 
 @app.middleware("http")
@@ -58,9 +55,7 @@ async def add_headers(request: Request, call_next: Callable) -> Response:
 
 @app.get("/api/download")
 def download(job_id: str, ordinal: str, file_path: str) -> Any:
-    file_list = redis.lrange("meta:" + job_id, ordinal, ordinal)
-
-    if not file_list or file_path != json.loads(file_list[0])["file"]:
+    if not db.job_contains(JobId(job_id), ordinal, file_path):
         raise NotFound("No such file in result set.")
 
     attach_name, ext = os.path.splitext(os.path.basename(file_path))
@@ -96,63 +91,31 @@ def query(
             for rule in rules
         ]
 
-    job_hash = "".join(
-        random.SystemRandom().choice(string.ascii_uppercase + string.digits)
-        for _ in range(12)
+    job = db.create_search_task(
+        rules[-1].name,
+        rules[-1].author,
+        data.raw_yara,
+        data.priority,
+        data.taint,
     )
-
-    job_obj = {
-        "status": "new",
-        "rule_name": rules[-1].name,
-        "rule_author": rules[-1].author,
-        "raw_yara": data.raw_yara,
-        "submitted": int(time.time()),
-        "priority": data.priority,
-    }
-
-    if data.taint is not None:
-        job_obj["taint"] = data.taint
-
-    redis.hmset("job:" + job_hash, job_obj)
-    redis.rpush("queue-search", job_hash)
-
-    return QueryResponseSchema(query_hash=job_hash)
+    return QueryResponseSchema(query_hash=job.hash)
 
 
 @app.get("/api/matches/{hash}", response_model=MatchesSchema)
 def matches(
     hash: str, offset: int = Query(...), limit: int = Query(...)
 ) -> MatchesSchema:
-    p = redis.pipeline(transaction=False)
-    p.hgetall("job:" + hash)
-    p.lrange("meta:" + hash, offset, offset + limit - 1)
-    job, meta = p.execute()
-    return MatchesSchema(job=job, matches=[json.loads(m) for m in meta])
-
-
-def get_job(job_id: str) -> JobSchema:
-    job = redis.hgetall(job_id)
-    return JobSchema(
-        id=job_id[4:],
-        status=job.get("status", "ERROR"),
-        rule_name=job.get("rule_name", "ERROR"),
-        rule_author=job.get("rule_author", None),
-        raw_yara=job.get("raw_yara", "ERROR"),
-        submitted=job.get("submitted", 0),
-        priority=job.get("priority", "medium"),
-        files_processed=job.get("files_processed", 0),
-        total_files=job.get("total_files", 0),
-    )
+    return db.get_job_matches(JobId(hash), offset, limit)
 
 
 @app.get("/api/job/{job_id}", response_model=JobSchema)
 def job_info(job_id: str) -> JobSchema:
-    return get_job(f"job:{job_id}")
+    return db.get_job(JobId(job_id))
 
 
 @app.delete("/api/job/{job_id}", response_model=StatusSchema)
 def job_cancel(job_id: str) -> StatusSchema:
-    redis.hmset("job:" + job_id, {"status": "cancelled"})
+    db.cancel_job(JobId(job_id))
     return StatusSchema(status="ok")
 
 
@@ -202,7 +165,7 @@ def user_jobs(name: str) -> List[JobSchema]:
 
 @app.get("/api/job", response_model=JobsSchema)
 def job_statuses() -> JobsSchema:
-    jobs = [get_job(j) for j in redis.keys("job:*")]
+    jobs = [db.get_job(job) for job in db.get_job_ids()]
     jobs = sorted(jobs, key=lambda j: j.submitted, reverse=True)
     return JobsSchema(jobs=jobs)
 
@@ -210,7 +173,7 @@ def job_statuses() -> JobsSchema:
 @app.get("/api/backend", response_model=BackendStatusSchema)
 def backend_status() -> BackendStatusSchema:
     db_alive = True
-    status = db.status()
+    status = ursa.status()
     try:
         tasks = status.get("result", {}).get("tasks", [])
         ursadb_version = status.get("result", {}).get(
@@ -236,7 +199,7 @@ def backend_status_datasets() -> BackendStatusDatasetsSchema:
     db_alive = True
 
     try:
-        datasets = db.topology().get("result", {}).get("datasets", {})
+        datasets = ursa.topology().get("result", {}).get("datasets", {})
     except Again:
         db_alive = False
         datasets = {}
@@ -258,7 +221,7 @@ def serve_index_sub() -> FileResponse:
 
 @app.get("/api/compactall")
 def compact_all() -> StatusSchema:
-    redis.rpush("queue-commands", "compact all;")
+    db.run_command("compact all;")
     return StatusSchema(status="ok")
 
 
diff --git a/src/daemon.py b/src/daemon.py
index 8af7cd73..da6ebefd 100644
--- a/src/daemon.py
+++ b/src/daemon.py
@@ -1,24 +1,23 @@
 #!/usr/bin/env python
-import json
 import logging
 import time
 import yara  # type: ignore
 from functools import lru_cache
-import random
 from yara import SyntaxError
 
 import config
 from lib.ursadb import UrsaDb
 from lib.yaraparse import parse_yara, combine_rules
-from util import make_redis, setup_logging
-from typing import Any, Dict, List, Optional, Tuple
+from util import setup_logging
+from typing import Any, Dict, List
+from db import Database, JobId, MatchInfo
 
-redis = make_redis()
-db = UrsaDb(config.BACKEND)
+db = Database()
+ursa = UrsaDb(config.BACKEND)
 
 
 @lru_cache(maxsize=32)
-def compile_yara(job_hash: str) -> Any:
-    yara_rule = redis.hget("job:" + job_hash, "raw_yara")
+def compile_yara(job: JobId) -> Any:
+    yara_rule = db.get_yara_by_job(job)
     logging.info("Compiling Yara")
     try:
@@ -30,100 +29,70 @@ def compile_yara(job: JobId) -> Any:
     return rule
 
 
-def get_list_name(priority: str) -> str:
-    if priority == "low":
-        return "list-yara-low"
-    elif priority == "medium":
-        return "list-yara-medium"
-    else:
-        return "list-yara-high"
-
-
 def collect_expired_jobs() -> None:
     if config.JOB_EXPIRATION_MINUTES <= 0:
         return
 
     exp_time = int(60 * config.JOB_EXPIRATION_MINUTES)  # conversion to seconds
 
-    job_hashes = []
-    for job_hash in redis.keys("job:*"):
-        job_hashes.append(job_hash[4:])
-
-    for job in job_hashes:
-        redis.set("gc-lock", "locked", ex=60)
-        job_submitted_time = int(redis.hget("job:" + job, "submitted"))
-        if (int(time.time()) - job_submitted_time) >= exp_time:
-            redis.hset("job:{}".format(job), "status", "expired")
-            redis.delete("meta:{}".format(job))
+    for job in db.get_job_ids():
+        job_submission_time = db.get_job_submitted(job)
+        if (int(time.time()) - job_submission_time) >= exp_time:
+            db.expire_job(job)
 
 
 def process_task(queue: str, data: str) -> None:
     if queue == "queue-search":
-        job_hash = data
-        logging.info(f"New task: {queue}:{job_hash}")
+        job = JobId(data)
+        logging.info(f"New task: {queue}:{job.hash}")
         try:
-            execute_search(job_hash)
+            execute_search(job)
         except Exception as e:
             logging.exception("Failed to execute job.")
-            redis.hmset(
-                "job:" + job_hash, {"status": "failed", "error": str(e)}
-            )
+            db.fail_job(None, job, str(e))
     elif queue == "queue-commands":
         logging.info("Running a command: %s", data)
-        resp = db.execute_command(data)
+        resp = ursa.execute_command(data)
         logging.info(resp)
 
 
-def get_random_job_by_priority() -> Tuple[Optional[str], str]:
-    yara_lists = ["list-yara-high", "list-yara-medium", "list-yara-low"]
-    for yara_list in yara_lists:
-        yara_jobs = redis.lrange(yara_list, 0, -1)
-        if yara_jobs:
-            return yara_list, random.choice(yara_jobs)
-    return None, ""
-
-
 def try_to_do_task() -> bool:
-    task_queues = ["queue-search", "queue-commands"]
-    task = None
-    for queue in task_queues:
-        task = redis.lpop(queue)
-        if task is not None:
-            data = task
-            process_task(queue, data)
-            return True
+    queue_and_task = db.get_task()
+    if queue_and_task is not None:
+        queue, task = queue_and_task
+        process_task(queue, task)
+        return True
     return False
 
 
 def try_to_do_search() -> bool:
-    yara_list, job_hash = get_random_job_by_priority()
-    if yara_list is None:
+    rnd_job = db.get_random_job_by_priority()
+    if rnd_job is None:
         return False
-
-    job_id = "job:" + job_hash
-    job_data = redis.hgetall(job_id)
+    yara_list, job = rnd_job
+    job_data = db.get_job(job)
 
     try:
         BATCH_SIZE = 500
-        pop_result = db.pop(job_data["iterator"], BATCH_SIZE)
+        if job_data.iterator is None:
+            raise RuntimeError(f"Job {job} has no iterator")
+        pop_result = ursa.pop(job_data.iterator, BATCH_SIZE)
         if pop_result.was_locked:
             return True
 
         if pop_result.files:
-            execute_yara(job_hash, pop_result.files)
+            execute_yara(job, pop_result.files)
 
         if pop_result.should_drop_iterator:
             logging.info(
                 "Iterator %s exhausted, removing job %s",
-                job_data["iterator"],
-                job_hash,
+                job_data.iterator,
+                job,
             )
-            redis.hset(job_id, "status", "done")
-            redis.lrem(yara_list, 0, job_hash)
+            db.finish_job(yara_list, job)
     except Exception as e:
         logging.exception("Failed to execute yara match.")
-        redis.hmset(job_id, {"status": "failed", "error": str(e)})
-        redis.lrem(yara_list, 0, job_hash)
+        db.fail_job(yara_list, job, str(e))
 
     return True
@@ -133,7 +102,7 @@ def job_daemon() -> None:
 
     for extractor in config.METADATA_EXTRACTORS:
         logging.info("Plugin loaded: %s", extractor.__class__.__name__)
-        extractor.set_redis(redis)
+        extractor.set_redis(db.unsafe_get_redis())
 
     logging.info("Daemon loaded, entering the main loop...")
 
@@ -144,13 +113,13 @@ def job_daemon() -> None:
         if try_to_do_search():
             continue
 
-        if redis.set("gc-lock", "locked", ex=60, nx=True):
+        if db.gc_lock():
            collect_expired_jobs()
 
         time.sleep(5)
 
 
-def update_metadata(job_hash: str, file_path: str, matches: List[str]) -> None:
+def update_metadata(job: JobId, file_path: str, matches: List[str]) -> None:
     current_meta: Dict[str, Any] = {}
 
     for extractor in config.METADATA_EXTRACTORS:
@@ -169,7 +138,7 @@ def update_metadata(job: JobId, file_path: str, matches: List[str]) -> None:
             # we build local dictionary for each extractor, thus enforcing dependencies to be declared correctly
             local_meta.update(current_meta[dep])
 
-        local_meta.update(job=job_hash)
+        local_meta.update(job=job.hash)
         current_meta[extr_name] = extractor.extract(file_path, local_meta)
 
     # flatten
@@ -178,15 +147,12 @@ def update_metadata(job: JobId, file_path: str, matches: List[str]) -> None:
     for v in current_meta.values():
         flat_meta.update(v)
 
-    redis.rpush(
-        "meta:{}".format(job_hash),
-        json.dumps({"file": file_path, "meta": flat_meta, "matches": matches}),
-    )
+    match = MatchInfo(file_path, flat_meta, matches)
+    db.add_match(job, match)
 
 
-def execute_yara(job_hash: str, files: List[str]) -> None:
-    job_id = f"job:{job_hash}"
-    if redis.hget(job_id, "status") in [
+def execute_yara(job: JobId, files: List[str]) -> None:
+    if db.get_job_status(job) in [
         "cancelled",
         "failed",
     ]:
@@ -195,38 +161,36 @@ def execute_yara(job: JobId, files: List[str]) -> None:
     if len(files) == 0:
         return
 
-    rule = compile_yara(job_hash)
+    rule = compile_yara(job)
     for sample in files:
         try:
             matches = rule.match(sample)
             if matches:
-                update_metadata(job_hash, sample, [r.rule for r in matches])
+                update_metadata(job, sample, [r.rule for r in matches])
         except yara.Error:
             logging.exception(f"Yara failed to check file {sample}")
         except FileNotFoundError:
             logging.exception(f"Failed to open file for yara check: {sample}")
 
-    redis.hincrby(job_id, "files_processed", len(files))
+    db.update_job(job, len(files))
 
 
-def execute_search(job_hash: str) -> None:
+def execute_search(job_id: JobId) -> None:
     logging.info("Parsing...")
 
-    job_id = "job:" + job_hash
-    job = redis.hgetall(job_id)
-    yara_rule = job["raw_yara"]
+    job = db.get_job(job_id)
+    yara_rule = job.raw_yara
 
-    redis.hmset(job_id, {"status": "parsing", "timestamp": time.time()})
+    db.set_job_to_parsing(job_id)
     rules = parse_yara(yara_rule)
     parsed = combine_rules(rules)
 
-    redis.hmset(job_id, {"status": "querying", "timestamp": time.time()})
+    db.set_job_to_querying(job_id)
     logging.info("Querying backend...")
-    taint = job.get("taint", None)
-    result = db.query(parsed.query, taint)
+    result = ursa.query(parsed.query, job.taint)
 
     if "error" in result:
         raise RuntimeError(result["error"])
@@ -234,21 +198,12 @@
     file_count = result["file_count"]
     iterator = result["iterator"]
     logging.info(f"Iterator contains {file_count} files")
-    redis.hmset(
-        job_id,
-        {
-            "status": "processing",
-            "iterator": iterator,
-            "files_processed": 0,
-            "total_files": file_count,
-        },
-    )
+    db.set_job_to_processing(job_id, iterator, file_count)
 
     if file_count > 0:
-        list_name = get_list_name(job["priority"])
-        redis.lpush(list_name, job_hash)
+        db.push_job_to_queue(job)
     else:
-        redis.hset(job_id, "status", "done")
+        db.finish_job(None, job_id)
 
 
 if __name__ == "__main__":
diff --git a/src/db.py b/src/db.py
new file mode 100644
index 00000000..99b8dcd2
--- /dev/null
+++ b/src/db.py
@@ -0,0 +1,254 @@
+from typing import List, Tuple, Optional, Dict, Any
+from schema import JobSchema, MatchesSchema, StorageSchema
+from time import time
+import json
+import random
+import string
+import config
+from datetime import datetime
+from redis import StrictRedis
+
+
+def make_redis() -> StrictRedis:
+    return StrictRedis(
+        host=config.REDIS_HOST, port=config.REDIS_PORT, decode_responses=True
+    )
+
+
+def get_list_name(priority: str) -> str:
+    if priority == "low":
+        return "list-yara-low"
+    elif priority == "medium":
+        return "list-yara-medium"
+    else:
+        return "list-yara-high"
+
+
+class JobId:
+    """ Represents a unique job ID in redis. Looks like this: `job:IU32AD3` """
+
+    def __init__(self, key: str) -> None:
+        """ Creates a new JobId object. Can take both key and raw hash. """
+        if not key.startswith("job:"):
+            key = f"job:{key}"
+        self.key = key
+        self.hash = key[4:]
+
+    @property
+    def meta_key(self) -> str:
+        """ Every job has exactly one related meta key """
+        return f"meta:{self.hash}"
+
+    def __repr__(self) -> str:
+        return self.key
+
+
+class JobQueue:
+    """ Represents one of the available job queues """
+
+    def __init__(self, name: str) -> None:
+        self.name = name
+
+    @classmethod
+    def available(cls) -> List["JobQueue"]:
+        names = ["list-yara-high", "list-yara-medium", "list-yara-low"]
+        return [cls(name) for name in names]
+
+    def __repr__(self) -> str:
+        return f"queue:{self.name}"
+
+
+class MatchInfo:
+    """ Represents information about a single match """
+
+    def __init__(
+        self, file: str, meta: Dict[str, Any], matches: List[str]
+    ) -> None:
+        self.file = file
+        self.meta = meta
+        self.matches = matches
+
+    def to_json(self) -> str:
+        """ Converts match info to json """
+        return json.dumps(
+            {"file": self.file, "meta": self.meta, "matches": self.matches}
+        )
+
+
+class Database:
+    def __init__(self) -> None:
+        self.redis = make_redis()
+
+    def get_yara_by_job(self, job: JobId) -> str:
+        """ Gets yara rule associated with job """
+        return self.redis.hget(job.key, "raw_yara")
+
+    def get_job_submitted(self, job: JobId) -> int:
+        """ Gets submission date of the job """
+        return int(self.redis.hget(job.key, "submitted"))
+
+    def get_job_status(self, job: JobId) -> str:
+        """ Gets status of the specified job """
+        return self.redis.hget(job.key, "status")
+
+    def get_job_ids(self) -> List[JobId]:
+        """ Gets IDs of all jobs in the database """
+        return [JobId(key) for key in self.redis.keys("job:*")]
+
+    def expire_job(self, job: JobId) -> None:
+        """ Sets the job status to expired, and removes it from the db """
+        self.redis.hset(job.key, "status", "expired")
+        self.redis.delete(job.meta_key)
+
+    def fail_job(
+        self, queue: Optional[JobQueue], job: JobId, message: str
+    ) -> None:
+        """ Sets the job status to failed, and removes it from job queues """
+        self.redis.hmset(job.key, {"status": "failed", "error": message})
+        if queue:
+            self.redis.lrem(queue.name, 0, job.hash)
+
+    def cancel_job(self, job: JobId) -> None:
+        """ Sets the job status to cancelled """
+        self.redis.hmset(job.key, {"status": "cancelled"})
+
+    def finish_job(self, queue: Optional[JobQueue], job: JobId) -> None:
+        """ Sets the job status to done, and removes it from job queues """
+        self.redis.hset(job.key, "status", "done")
+        if queue:
+            self.redis.lrem(queue.name, 0, job.hash)
+
+    def set_job_to_processing(
+        self, job: JobId, iterator: str, file_count: int
+    ) -> None:
+        self.redis.hmset(
+            job.key,
+            {
+                "status": "processing",
+                "iterator": iterator,
+                "files_processed": 0,
+                "total_files": file_count,
+            },
+        )
+
+    def update_job(self, job: JobId, files_processed: int) -> None:
+        self.redis.hincrby(job.key, "files_processed", files_processed)
+
+    def set_job_to_parsing(self, job: JobId) -> None:
+        """ Sets the job status to parsing """
+        self.redis.hmset(job.key, {"status": "parsing", "timestamp": time()})
+
+    def set_job_to_querying(self, job: JobId) -> None:
+        """ Sets the job status to querying """
+        self.redis.hmset(job.key, {"status": "querying", "timestamp": time()})
+
+    def gc_lock(self) -> bool:
+        """ Tries to get a GC lock, and returns true if it succeeded """
+        return bool(self.redis.set("gc-lock", "locked", ex=60, nx=True))
+
+    def push_job_to_queue(self, job: JobSchema) -> None:
+        list_name = get_list_name(job.priority)
+        self.redis.lpush(list_name, job.id)
+
+    def get_random_job_by_priority(self) -> Optional[Tuple[JobQueue, JobId]]:
+        """ Tries to get a random job along with its queue """
+        for queue in JobQueue.available():
+            yara_jobs = self.redis.lrange(queue.name, 0, -1)
+            if yara_jobs:
+                return queue, JobId(random.choice(yara_jobs))
+        return None
+
+    def get_job(self, job: JobId) -> JobSchema:
+        data = self.redis.hgetall(job.key)
+        return JobSchema(
+            id=job.hash,
+            status=data.get("status", "ERROR"),
+            rule_name=data.get("rule_name", "ERROR"),
+            rule_author=data.get("rule_author", None),
+            raw_yara=data.get("raw_yara", "ERROR"),
+            submitted=data.get("submitted", 0),
+            priority=data.get("priority", "medium"),
+            files_processed=data.get("files_processed", 0),
+            total_files=data.get("total_files", 0),
+            iterator=data.get("iterator", None),
+            taint=data.get("taint", None),
+        )
+
+    def add_match(self, job: JobId, match: MatchInfo) -> None:
+        self.redis.rpush(job.meta_key, match.to_json())
+
+    def job_contains(self, job: JobId, ordinal: str, file_path: str) -> bool:
+        file_list = self.redis.lrange(job.meta_key, ordinal, ordinal)
+        return file_list and file_path == json.loads(file_list[0])["file"]
+
+    def create_search_task(
+        self,
+        rule_name: str,
+        rule_author: str,
+        raw_yara: str,
+        priority: Optional[str],
+        taint: Optional[str],
+    ) -> JobId:
+        job = JobId(
+            "".join(
+                random.SystemRandom().choice(
+                    string.ascii_uppercase + string.digits
+                )
+                for _ in range(12)
+            )
+        )
+        job_obj = {
+            "status": "new",
+            "rule_name": rule_name,
+            "rule_author": rule_author,
+            "raw_yara": raw_yara,
+            "submitted": int(time()),
+            "priority": priority,
+        }
+        if taint is not None:
+            job_obj["taint"] = taint
+
+        self.redis.hmset(job.key, job_obj)
+        self.redis.rpush("queue-search", job.hash)
+        return job
+
+    def get_job_matches(
+        self, job: JobId, offset: int, limit: int
+    ) -> MatchesSchema:
+        p = self.redis.pipeline(transaction=False)
+        p.hgetall(job.key)
+        p.lrange("meta:" + job.hash, offset, offset + limit - 1)
+        job, meta = p.execute()
+        return MatchesSchema(job=job, matches=[json.loads(m) for m in meta])
+
+    def run_command(self, command: str) -> None:
+        self.redis.rpush("queue-commands", command)
+
+    def get_task(self) -> Optional[Tuple[str, str]]:
+        task_queues = ["queue-search", "queue-commands"]
+        for queue in task_queues:
+            task = self.redis.lpop(queue)
+            if task is not None:
+                return queue, task
+        return None
+
+    def unsafe_get_redis(self) -> StrictRedis:
+        return self.redis
+
+    def get_storage(self, storage_id: str) -> StorageSchema:
+        data = self.redis.hgetall(storage_id)
+        return StorageSchema(
+            id=storage_id,
+            name=data["name"],
+            path=data["path"],
+            indexing_job_id=None,
+            last_update=datetime.fromtimestamp(data["timestamp"]),
+            taints=data["taints"],
+            enabled=data["enabled"],
+        )
+
+    def get_storages(self) -> List[StorageSchema]:
+        return [
+            self.get_storage(storage_id)
+            for storage_id in self.redis.keys("storage:*")
+        ]
diff --git a/src/lib/ursadb.py b/src/lib/ursadb.py
index fb382cc8..a7d280ca 100644
--- a/src/lib/ursadb.py
+++ b/src/lib/ursadb.py
@@ -1,7 +1,7 @@
 import json
 import time
 import zmq  # type: ignore
-from typing import Dict, Any, List
+from typing import Dict, Any, List, Optional
 
 
 Json = Dict[str, Any]
@@ -32,7 +32,7 @@ def make_socket(self, recv_timeout: int = 2000) -> zmq.Context:
         socket.connect(self.backend)
         return socket
 
-    def query(self, query: str, taint: str) -> Json:
+    def query(self, query: str, taint: Optional[str]) -> Json:
         socket = self.make_socket(recv_timeout=-1)
 
         start = time.clock()
diff --git a/src/schema.py b/src/schema.py
index 7a98f085..0a3bba83 100644
--- a/src/schema.py
+++ b/src/schema.py
@@ -1,7 +1,6 @@
 from enum import Enum
 from typing import List, Dict, Optional
 from datetime import datetime
-
 from pydantic import BaseModel
 
 
@@ -15,6 +14,8 @@ class JobSchema(BaseModel):
     priority: str
     files_processed: int
     total_files: int
+    iterator: Optional[str]
+    taint: Optional[str]
 
 
 class JobsSchema(BaseModel):
diff --git a/src/util.py b/src/util.py
index 2187d7fb..2ca8eca4 100644
--- a/src/util.py
+++ b/src/util.py
@@ -1,10 +1,5 @@
 import logging
-from redis import StrictRedis
-
-import config
-from typing import Any
-
 
 LOG_FORMAT = "[%(asctime)s][%(levelname)s] %(message)s"
 LOG_DATEFMT = "%d/%m/%Y %H:%M:%S"
@@ -16,11 +11,5 @@ def setup_logging() -> None:
     )
 
 
-def make_redis() -> Any:
-    return StrictRedis(
-        host=config.REDIS_HOST, port=config.REDIS_PORT, decode_responses=True
-    )
-
-
 def mquery_version():
     return "1.1.0"
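
Usage note (not part of the patch): the snippet below is a minimal sketch of how the new Database / JobId / MatchInfo layer from src/db.py can be used, assuming the patch is applied and a Redis instance is reachable through the REDIS_HOST/REDIS_PORT settings in config.py. The rule name, author, YARA source, and file path are made-up placeholders, not values taken from the repository.

# Minimal usage sketch of the db.py abstraction introduced above (not part of the diff).
from db import Database, JobId, MatchInfo

db = Database()

# Enqueue a search task; the returned JobId wraps the "job:<hash>" redis key.
job = db.create_search_task(
    rule_name="example_rule",   # placeholder rule name
    rule_author="analyst",      # placeholder author
    raw_yara="rule example_rule { condition: true }",
    priority="medium",
    taint=None,
)
print(job)           # repr() is the full redis key, e.g. job:<12-char hash>
print(job.meta_key)  # meta:<hash> - the list that stores per-file match info

# A worker records matches through MatchInfo; readers get them back the same way.
db.add_match(job, MatchInfo("/tmp/sample.exe", {"job": job.hash}, ["example_rule"]))
print(db.get_job(job).status)                       # "new" until the daemon picks the job up
print(db.get_job_matches(job, offset=0, limit=10))  # MatchesSchema with the match above

The calls above map onto the same redis keys the old code touched directly (job:<hash>, meta:<hash>, queue-search, list-yara-*), so the sketch only re-expresses what app.py and daemon.py now do through the Database wrapper.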