Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[py-tx][backwards incompatible] Split Update Format and Index Format #1014

Merged
merged 6 commits into from
May 26, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
81 changes: 23 additions & 58 deletions python-threatexchange/threatexchange/cli/cli_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,19 +10,18 @@
3. Index state - serializations of indexes for SignalType
"""

import json
import pickle
import pathlib
import typing as t
import dataclasses
import logging

from threatexchange.signal_type.index import SignalTypeIndex
from threatexchange.signal_type.signal_base import SignalType
from threatexchange.cli.exceptions import CommandError
from threatexchange.cli import dataclass_json
from threatexchange.fetcher.collab_config import CollaborationConfigBase
from threatexchange.fetcher.fetch_state import (
FetchCheckpointBase,
FetchDelta,
FetchedSignalMetadata,
)
from threatexchange.fetcher.simple import state as simple_state
Expand Down Expand Up @@ -124,60 +123,45 @@ def clear(self, collab: CollaborationConfigBase) -> None:
def _read_state(
self,
collab_name: str,
) -> t.Optional[
t.Tuple[
t.Dict[str, t.Dict[str, FetchedSignalMetadata]],
FetchCheckpointBase,
]
]:
) -> t.Optional[simple_state.T_FetchDelta]:
file = self.collab_file(collab_name)
if not file.is_file():
return None
try:
with file.open("r") as f:
json_dict = json.load(f)

checkpoint = dataclass_json.dataclass_load_dict(
json_dict=json_dict[self.JSON_CHECKPOINT_KEY],
cls=self.api_cls.get_checkpoint_cls(),
with file.open("rb") as f:
delta = pickle.load(f)

assert isinstance(delta, FetchDelta), "Unexpected class type?"
delta = t.cast(simple_state.T_FetchDelta, delta)
assert (
delta.next_checkpoint().__class__.__name__
== self.api_cls.get_checkpoint_cls().__name__
), "wrong checkpoint class?"

logging.debug(
"Loaded %s with %d records", collab_name, delta.record_count()
)
records = json_dict[self.JSON_RECORDS_KEY]

logging.debug("Loaded %s with records for: %s", collab_name, list(records))
# Minor stab at lowering memory footprint by converting kinda
# inline
for stype in list(records):
records[stype] = {
signal: dataclass_json.dataclass_load_dict(
json_dict=json_record,
cls=self.api_cls.get_record_cls(),
)
for signal, json_record in records[stype].items()
}
return records, checkpoint
return delta
except Exception:
logging.exception("Failed to read state for %s", collab_name)
raise CommandError(
f"Failed to read state for {collab_name}. "
"You might have to delete it with `threatexchange fetch --clear`"
)

def _write_state( # type: ignore[override] # fix with generics on base
def _write_state( # type: ignore[override] # Fix in followup PR
self,
collab_name: str,
updates_by_type: t.Dict[str, t.Dict[str, FetchedSignalMetadata]],
checkpoint: FetchCheckpointBase,
delta: simple_state.SimpleFetchDelta[
FetchCheckpointBase, FetchedSignalMetadata
],
) -> None:
file = self.collab_file(collab_name)
if not file.parent.exists():
file.parent.mkdir(parents=True)

record_sanity_check = next(
(
record
for records in updates_by_type.values()
for record in records.values()
),
(record for record in delta.update_record.values()),
None,
)

Expand All @@ -188,26 +172,7 @@ def _write_state( # type: ignore[override] # fix with generics on base
f"Record cls: want {self.api_cls.get_record_cls().__name__} "
f"got {record_sanity_check.__class__.__name__}"
)

json_dict = {
self.JSON_CHECKPOINT_KEY: dataclasses.asdict(checkpoint),
self.JSON_RECORDS_KEY: {
stype: {
s: dataclasses.asdict(record)
for s, record in signal_to_record.items()
}
for stype, signal_to_record in updates_by_type.items()
},
}

tmpfile = file.with_name(f".{file.name}")

with tmpfile.open("w") as f:
json.dump(json_dict, f, indent=2, default=_json_set_default)
with tmpfile.open("wb") as f:
pickle.dump(delta, f)
tmpfile.rename(file)


def _json_set_default(obj):
if isinstance(obj, set):
return list(obj)
raise TypeError
11 changes: 7 additions & 4 deletions python-threatexchange/threatexchange/cli/fetch_cmd.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from threatexchange.fetcher.fetch_state import (
FetchCheckpointBase,
FetchDelta,
FetchedSignalMetadata,
FetchedStateStoreBase,
)
from threatexchange.cli import command_base
Expand Down Expand Up @@ -169,13 +170,15 @@ def execute_for_collab(

try:
while not self.has_hit_limits():
delta: FetchDelta[FetchCheckpointBase] = fetcher.fetch_once(
delta: FetchDelta[
FetchCheckpointBase, FetchedSignalMetadata
] = fetcher.fetch_once(
settings.get_all_signal_types(), collab, checkpoint
)
logging.info("Fetched %d records", delta.record_count())
checkpoint = delta.next_checkpoint()
self._fetch_progress(delta.record_count(), checkpoint)
assert checkpoint is not None # Infinite loop protection
next_checkpoint = delta.next_checkpoint()
self._fetch_progress(delta.record_count(), next_checkpoint)
assert next_checkpoint is not None # Infinite loop protection
store.merge(collab, delta)
if not delta.has_more():
break
Expand Down
3 changes: 0 additions & 3 deletions python-threatexchange/threatexchange/cli/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,13 +12,10 @@
"""

import argparse
from dataclasses import dataclass
from distutils import extension
import logging
import inspect
import os
import sys
from textwrap import dedent
import typing as t
import pathlib

Expand Down
6 changes: 3 additions & 3 deletions python-threatexchange/threatexchange/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,9 +87,9 @@ def argparse_choices_pre_type(choices: t.List[str], type: t.Callable[[str], t.An
def ret(s: str):
if s not in choices:
raise argparse.ArgumentTypeError(
"invalid choice: %s (choose from %s)",
s,
", ".join(repr(c) for c in choices),
"invalid choice: {} (choose from {})".format(
s, ", ".join(repr(c) for c in choices)
),
)
return type(s)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ def fetch_once(
with path.open("r") as f:
lines = f.readlines()

updates = {}
updates: t.Dict[t.Tuple[str, str], t.Optional[state.FetchedSignalMetadata]] = {}
for line in lines:
signal_type = collab.signal_type
signal = line.strip()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -51,8 +51,12 @@ def fetch_once(
for stype in supported_signal_types:
sample_signals.extend(_signals(stype))

return SimpleFetchDelta(
dict(sample_signals),
updates: t.Dict[
t.Tuple[str, str], t.Optional[state.FetchedSignalMetadata]
] = dict(sample_signals)

return TDelta(
updates,
state.FetchCheckpointBase(),
done=True,
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,7 @@ def fetch_iter(
start_time = checkpoint.update_time
for result in self.api.fetch_hashes_iter(start_timestamp=start_time):
translated = (_get_delta_mapping(r) for r in result.hashRecords)
yield SimpleFetchDelta(
yield SimpleFetchDelta[StopNCIICheckpoint, StopNCIISignalMetadata](
dict(t for t in translated if t[0][0]),
StopNCIICheckpoint.from_stopncii_fetch(result),
done=not result.hasMoreRecords,
Expand Down
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved

import pytest
import typing as t

Expand Down Expand Up @@ -37,7 +39,7 @@ def test_fetch(fetcher: SignalExchangeAPI, monkeypatch: pytest.MonkeyPatch):
assert checkpoint.update_time == 1625175071
assert checkpoint.last_fetch_time == 10**8

updates = delta.get_as_update_dict()
updates = delta.update_record
assert len(updates) == 2
assert {t[0] for t in updates} == {"pdq"}

Expand All @@ -57,7 +59,7 @@ def test_fetch(fetcher: SignalExchangeAPI, monkeypatch: pytest.MonkeyPatch):
delta = t.cast(SimpleFetchDelta, fetcher.fetch_once([], collab, None))
assert delta.has_more() is False
assert delta.record_count() == 1
updates = delta.get_as_update_dict()
updates = delta.update_record
assert len(updates) == 1
assert "pdq" == tuple(updates)[0][0]
a = t.cast(StopNCIISignalMetadata, tuple(updates.values())[0])
Expand Down
12 changes: 12 additions & 0 deletions python-threatexchange/threatexchange/fetcher/fetch_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -212,6 +212,18 @@ def report_false_positive(
)


# A convenience helper since mypy can't intuit that bound != t.Any
# For methods like get_checkpoint_cls
# Fully-bound alias of the generic SignalExchangeAPI, specialized to the
# base classes of each type parameter (config, checkpoint, metadata, delta).
# Useful for annotating code that works with any exchange API implementation.
TSignalExchangeAPI = SignalExchangeAPI[
    CollaborationConfigBase,
    state.FetchCheckpointBase,
    state.FetchedSignalMetadata,
    state.FetchDelta[state.FetchCheckpointBase, state.FetchedSignalMetadata],
]

# Alias for the class object itself (e.g. for registries of API types,
# or calling classmethods like get_checkpoint_cls without an instance).
TSignalExchangeAPICls = t.Type[TSignalExchangeAPI]


class SignalExchangeAPIWithIterFetch(
SignalExchangeAPI[
TCollabConfig, state.TFetchCheckpoint, state.TFetchedSignalMetadata, TFetchDelta
Expand Down
48 changes: 30 additions & 18 deletions python-threatexchange/threatexchange/fetcher/fetch_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -181,12 +181,16 @@ def __str__(self) -> str:
)


class FetchDelta(t.Generic[TFetchCheckpoint]):
class FetchDelta(t.Generic[TFetchCheckpoint, TFetchedSignalMetadata]):
"""
Contains the result of a fetch.

You'll need to extend this, but it only to be interpretable by your
API's version of FetchedState
In order to make naive storage (such as on the CLI) work, implementations
of this class must be pickle-able.

Note that the way that this class organizes and stores data does not
need to be the same way that the index class organizes data,
which is hash => Record.
"""

def record_count(self) -> int:
Expand All @@ -199,28 +203,36 @@ def next_checkpoint(self) -> TFetchCheckpoint:

def has_more(self) -> bool:
"""
Returns true if the API has no more data at this time.
Returns false if the API has no more data at this time.
"""
raise NotImplementedError

def merge(self: Self, newer: Self) -> None:
"""
Merge the content of a subsequent fetch() call into this one.

class FetchDeltaWithUpdateStream(
t.Generic[TFetchCheckpoint, TFetchedSignalMetadata], FetchDelta[TFetchCheckpoint]
):
"""
For most APIs, they can be represented in a simple update stream.
Different APIs might have different approaches to merging.

This allows naive implementations for storage.
"""
You can assume the caller has kept track, and is only
merging in sequential order.

delta_1 = api.fetch_once(...)
delta_2 = api.fetch_once(..., delta_1.checkpoint)
delta_3 = api.fetch_once(..., delta_2.checkpoint)

def get_as_update_dict(
self,
) -> t.Mapping[t.Tuple[str, str], t.Optional[TFetchedSignalMetadata]]:
delta_1.merge(delta_2)
delta_1.merge(delta_3)
"""
Returns the contents of the delta as
(signal_type, signal_str) => record
If the record is set to None, this indicates the record should be
deleted if it exists.
raise NotImplementedError

def get_for_signal_type(
self, signal_type: t.Type[SignalType]
) -> t.Dict[str, TFetchedSignalMetadata]:
"""
Get as a map of signal => Metadata

This powers simple storage solutions, and provides the mapping
from how the API provides update to how the index needs.
"""
raise NotImplementedError

Expand Down
Loading