Skip to content

Commit

Permalink
[omm] Hash api implementation first draft (#1355)
Browse files Browse the repository at this point in the history
  • Loading branch information
Dcallies authored Sep 12, 2023
1 parent 0eca9d6 commit cc04f20
Show file tree
Hide file tree
Showing 7 changed files with 145 additions and 26 deletions.
6 changes: 5 additions & 1 deletion open-media-match/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -20,13 +20,17 @@ dependencies = [
"flask_sqlalchemy",
"flask_migrate",
"psycopg2",
"threatexchange",
]

[project.optional-dependencies]
all = [
"mypy",
"black",
"pytest"
"pytest",
"types-Flask-Migrate",
"types-requests",

]

test = [ "pytest" ]
Expand Down
11 changes: 10 additions & 1 deletion open-media-match/src/OpenMediaMatch/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.

import os
import sys

import flask
import flask_migrate
import flask_sqlalchemy
Expand All @@ -16,7 +18,14 @@ def create_app():
Create and configure the Flask app
"""
app = flask.Flask(__name__)
app.config.from_envvar("OMM_CONFIG")
if "OMM_CONFIG" in os.environ:
app.config.from_envvar("OMM_CONFIG")
elif sys.argv[0].endswith("/flask"): # Default for flask CLI
# The devcontainer settings. If you are using the CLI outside
# the devcontainer and getting an error, just override the env
app.config.from_pyfile("/workspace/.devcontainer/omm_config.py")
else:
raise RuntimeError("No flask config given - try populating OMM_CONFIG env")
app.config.update(
SQLALCHEMY_DATABASE_URI=app.config.get("DATABASE_URI"),
SQLALCHEMY_TRACK_MODIFICATIONS=False,
Expand Down
25 changes: 25 additions & 0 deletions open-media-match/src/OpenMediaMatch/app_resources.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.

"""
Accessors for various "global" resources, usually cached by request lifetime
I can't tell if these should just be in app.py, so I'm sticking it here for now,
since one advantage of putting these in functions is we can type the output.
"""

from flask import g

from OpenMediaMatch.storage.interface import IUnifiedStore
from OpenMediaMatch.storage.default import DefaultOMMStore


def get_storage() -> IUnifiedStore:
"""
Get the storage object, which is just a wrapper around the real storage.
"""
if "storage" not in g:
# dougneal, you'll need to eventually add constructor arguments
# for this to pass in the postgres/database object. We're just
# hiding flask bits from pytx bits
g.storage = DefaultOMMStore()
return g.storage
91 changes: 77 additions & 14 deletions open-media-match/src/OpenMediaMatch/blueprints/hashing.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,21 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.

"""
Endpoints for hashing content
"""

from pathlib import Path
import tempfile
import typing as t

from flask import Blueprint
from flask import abort, request
from flask import abort, request, current_app
import requests

from threatexchange.content_type.content_base import ContentType
from threatexchange.signal_type.signal_base import FileHasher, SignalType

from OpenMediaMatch import app_resources

bp = Blueprint("hashing", __name__)

Expand All @@ -10,18 +26,65 @@ def hash_media():
Fetch content and return its hash.
TODO: implement
"""

content_type = _parse_request_content_type()
signal_types = _parse_request_signal_type(content_type)

media_url = request.args.get("url", None)
if media_url is None:
# path is required, otherwise we don't know what we're hashing.
# TODO: a more helpful message
abort(400)

hash_types = request.args.get("types", None)
if hash_types is not None:
# TODO: parse this into a list of hash types
pass

# TODO
# - download the media
# - decode it
# - hash it
abort(400, "url is required")

download_resp = requests.get(media_url, allow_redirects=True, timeout=30 * 1000)
download_resp.raise_for_status()

ret = {}

# For images, we may need to copy the file suffix (.png, jpeg, etc) for it to work
with tempfile.NamedTemporaryFile("wb") as tmp:
current_app.logger.debug("Writing to %s", tmp.name)
tmp.write(download_resp.content)
path = Path(tmp.name)
for st in signal_types.values():
# At this point, every BytesHasher is a FileHasher, but we could
# explicitly pull those out to avoiding storing any copies of
# data locally, even temporarily
if issubclass(st, FileHasher):
ret[st.get_name()] = st.hash_from_file(path)
return ret


def _parse_request_content_type() -> ContentType:
storage = app_resources.get_storage()
arg = request.args.get("content_type", "")
content_type_config = storage.get_content_type_configs().get(arg)
if content_type_config is None:
abort(400, f"no such content_type: '{arg}'")

if not content_type_config.enabled:
abort(400, f"content_type {arg} is disabled")

return content_type_config.content_type


def _parse_request_signal_type(content_type: ContentType) -> t.Mapping[str, SignalType]:
storage = app_resources.get_storage()
signal_types = storage.get_enabled_signal_types_for_content_type(content_type)
if not signal_types:
abort(500, "No signal types configured!")
signal_type_args = request.args.get("types", None)
if signal_type_args is None:
return signal_types

ret = {}
for st_name in signal_type_args.split(","):
st_name = st_name.strip()
if not st_name:
continue
if st_name not in signal_types:
abort(400, f"signal type '{st_name}' doesn't exist or is disabled")
ret[st_name] = signal_types[st_name]

if not ret:
abort(400, "empty signal type selection")

return ret
6 changes: 3 additions & 3 deletions open-media-match/src/OpenMediaMatch/migrations/env.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,17 @@
import logging
from logging.config import fileConfig

from flask import current_app

from alembic import context
from flask import current_app

# this is the Alembic Config object, which provides
# access to the values within the .ini file in use.
# It's also impossible to type!
config = context.config

# Interpret the config file for Python logging.
# This line sets up loggers basically.
fileConfig(config.config_file_name)
fileConfig(config.config_file_name) # type: ignore
logger = logging.getLogger("alembic.env")


Expand Down
6 changes: 3 additions & 3 deletions open-media-match/src/OpenMediaMatch/models.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,16 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.

from . import database as db
from OpenMediaMatch import database as db


class Bank(db.Model):
class Bank(db.Model): # type: ignore[name-defined] # mypy not smart enough
__tablename__ = "banks"
id = db.Column(db.Integer, primary_key=True, autoincrement=True)
name = db.Column(db.String(255), nullable=False)
enabled = db.Column(db.Boolean, nullable=False)


class Hash(db.Model):
class Hash(db.Model): # type: ignore[name-defined] # mypy not smart enough
__tablename__ = "hashes"
id = db.Column(db.Integer, primary_key=True, autoincrement=True)
enabled = db.Column(db.Boolean, nullable=False)
Expand Down
26 changes: 22 additions & 4 deletions open-media-match/src/OpenMediaMatch/storage/interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ class ContentTypeConfig:

# Content types that are not enabled should not be used in hashing/matching
enabled: bool
signal_type: ContentType
content_type: ContentType


class IContentTypeConfigStore(metaclass=abc.ABCMeta):
Expand Down Expand Up @@ -62,9 +62,27 @@ class ISignalTypeConfigStore(metaclass=abc.ABCMeta):

@abc.abstractmethod
def get_signal_type_configs(self) -> t.Mapping[str, SignalTypeConfig]:
"""
Return all installed signal types.
"""
"""Return all installed signal types."""

@t.final
def get_enabled_signal_types(self) -> t.Mapping[str, SignalType]:
"""Helper shortcut for getting only enabled SignalTypes"""
return {
k: v.signal_type
for k, v in self.get_signal_type_configs().items()
if v.enabled
}

@t.final
def get_enabled_signal_types_for_content_type(
self, content_type: ContentType
) -> t.Mapping[str, SignalType]:
"""Helper shortcut for getting enabled types for a piece of content"""
return {
k: v.signal_type
for k, v in self.get_signal_type_configs().items()
if v.enabled and content_type in v.signal_type.get_content_types()
}


class ISignalExchangeConfigStore(metaclass=abc.ABCMeta):
Expand Down

0 comments on commit cc04f20

Please sign in to comment.