Skip to content

Commit

Permalink
[omm] Fetch interfaces
Browse files Browse the repository at this point in the history
  • Loading branch information
Dcallies committed Sep 25, 2023
1 parent 35a8ed4 commit 013e5d8
Show file tree
Hide file tree
Showing 4 changed files with 126 additions and 1 deletion.
Empty file.
36 changes: 36 additions & 0 deletions open-media-match/src/OpenMediaMatch/background_tasks/fetcher.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.

import typing as t

from OpenMediaMatch.storage.interface import ICollaborationStore, SignalType
from threatexchange.exchanges.fetch_state import CollaborationConfigBase


def fetch_all(
    collab_store: ICollaborationStore,
    enabled_signal_types: t.Dict[str, t.Type[SignalType]],
) -> None:
    """
    Run fetch() once for every collaboration config held by the store.

    Args:
        collab_store: supplies the collaboration configs via
            get_collaborations() and is handed on to each fetch() call.
        enabled_signal_types: mapping of string key to signal type class,
            passed through to fetch() unchanged.
    """
    # Only the config objects are needed; the dict keys are not used here.
    for collab in collab_store.get_collaborations().values():
        fetch(collab_store, enabled_signal_types, collab)


def fetch(
    config: ICollaborationStore,
    enabled_signal_types: t.Dict[str, t.Type[SignalType]],
    collab: CollaborationConfigBase,
) -> None:
    """
    Fetch data for a single collaboration.

    Not implemented yet: the body is only a TODO. The intended steps are:
    1. Attempt to authenticate with that collaboration's API
       using stored credentials.
    2. Load the fetch checkpoint from storage
    3. Resume the fetch at the checkpoint
    4. Download new data
    5. Send the new data to storage (saving the new checkpoint)

    Args:
        config: the collaboration store (fetch_all() passes its
            collab_store here).
        enabled_signal_types: mapping of string key to signal type class.
            NOTE(review): presumably keyed by signal type name - confirm.
        collab: the collaboration config to fetch for.
    """
    # TODO
53 changes: 52 additions & 1 deletion open-media-match/src/OpenMediaMatch/storage/interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,10 @@
from threatexchange.content_type.content_base import ContentType
from threatexchange.signal_type.signal_base import SignalType
from threatexchange.signal_type.index import SignalTypeIndex
from threatexchange.exchanges.fetch_state import (
FetchCheckpointBase,
CollaborationConfigBase,
)
from threatexchange.exchanges.signal_exchange_api import (
TSignalExchangeAPICls,
)
Expand Down Expand Up @@ -122,14 +126,61 @@ def get_signal_type_index(
"""


# TODO - index, collaborations, banks, OMM-specific
class ICollaborationStore(metaclass=abc.ABCMeta):
    """
    Storage interface for collaboration configs and the data fetched for them.

    Implementations provide: listing the configs, reading the last saved
    fetch checkpoint, committing newly fetched data together with its
    checkpoint, and reading stored data back by key.
    """

    @abc.abstractmethod
    def get_collaborations(self) -> t.Dict[str, CollaborationConfigBase]:
        """
        Get all collaboration configs.

        Collaboration configs control the syncing of data from external
        sources to banks of labeled content locally.

        NOTE(review): the meaning of the dict key is not stated here;
        presumably the collaboration name - confirm.
        """

    @abc.abstractmethod
    def get_collab_fetch_checkpoint(
        self, collab: CollaborationConfigBase
    ) -> t.Optional[FetchCheckpointBase]:
        """
        Get the last saved checkpoint for the fetch of this collaboration.

        If there is no previous fetch, returns None, indicating the fetch
        should start from the beginning.
        """

    @abc.abstractmethod
    def commit_collab_fetch_data(
        self,
        collab: CollaborationConfigBase,
        dat: t.Dict[str, t.Any],
        checkpoint: FetchCheckpointBase,
    ) -> None:
        """
        Commit a sequentially fetched set of data from a fetch().

        Advances the checkpoint if it's different than the previous one.

        Args:
            collab: the collaboration the data was fetched for.
            dat: the fetched data.
                NOTE(review): key/value schema is not specified here;
                presumably API-specific - confirm.
            checkpoint: the checkpoint that goes with this set of data.
        """

    @abc.abstractmethod
    def get_collab_data(
        self,
        collab_name: str,
        key: str,
        checkpoint: FetchCheckpointBase,
    ) -> t.Any:
        """
        Get API-specific collaboration data by key.

        Note that this takes the collaboration's name (a str), whereas the
        other methods on this interface take the config object.

        NOTE(review): the role of `checkpoint` in the lookup is not
        documented here - confirm the intended semantics.
        """


# TODO - banks


class IUnifiedStore(
IContentTypeConfigStore,
ISignalTypeConfigStore,
ISignalExchangeConfigStore,
ISignalTypeIndexStore,
ICollaborationStore,
metaclass=abc.ABCMeta,
):
"""
Expand Down
38 changes: 38 additions & 0 deletions open-media-match/src/OpenMediaMatch/storage/mocked.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,10 @@
from threatexchange.signal_type.pdq.signal import PdqSignal
from threatexchange.signal_type.md5 import VideoMD5Signal
from threatexchange.signal_type.signal_base import SignalType
from threatexchange.exchanges.fetch_state import (
FetchCheckpointBase,
CollaborationConfigBase,
)

from OpenMediaMatch.storage import interface
from OpenMediaMatch.storage.interface import SignalTypeConfig
Expand All @@ -34,6 +38,7 @@ def get_signal_type_configs(self) -> t.Mapping[str, SignalTypeConfig]:
s_types: t.Sequence[t.Type[SignalType]] = (PdqSignal, VideoMD5Signal)
return {s.get_name(): interface.SignalTypeConfig(True, s) for s in s_types}

# Index
def get_signal_type_index(
self, signal_type: type[SignalType]
) -> t.Optional[SignalTypeIndex[int]]:
Expand All @@ -43,3 +48,36 @@ def get_signal_type_index(
set(signal_type.get_examples()), start=1
)
)

# Collabs
def get_collaborations(self) -> t.Dict[str, CollaborationConfigBase]:
    """
    Return one hard-coded, enabled sample collaboration keyed by its name.
    """
    sample_api = StaticSampleSignalExchangeAPI
    config_type = sample_api.get_config_cls()
    test_collab = config_type(
        "c-TEST", api=sample_api.get_name(), enabled=True
    )
    # Key the mapping on the config's own name attribute.
    return {test_collab.name: test_collab}

def get_collab_fetch_checkpoint(
    self, collab: CollaborationConfigBase
) -> t.Optional[FetchCheckpointBase]:
    """
    Mocked lookup: always returns None, whatever collaboration is given.
    """
    saved_checkpoint = None  # this mock keeps no checkpoint state
    return saved_checkpoint

def commit_collab_fetch_data(
    self,
    collab: CollaborationConfigBase,
    dat: t.Dict[str, t.Any],
    checkpoint: FetchCheckpointBase,
):
    """
    Mocked commit: accepts the arguments and does nothing with them.
    """
    # Intentionally a no-op; nothing is stored.
    return None

def get_collab_data(
    self,
    collab_name: str,
    key: str,
    checkpoint: FetchCheckpointBase,
) -> t.Any:
    """
    Mocked read: always returns None for any name, key and checkpoint.
    """
    stored_value = None  # this mock holds no collaboration data
    return stored_value

0 comments on commit 013e5d8

Please sign in to comment.