Skip to content

Commit

Permalink
create dataclass IndexMatchWithRotation and use that in match_cmd
Browse files Browse the repository at this point in the history
  • Loading branch information
haianhng31 committed Oct 30, 2024
1 parent 13bf6e7 commit 594782b
Show file tree
Hide file tree
Showing 3 changed files with 71 additions and 42 deletions.
103 changes: 62 additions & 41 deletions python-threatexchange/threatexchange/cli/match_cmd.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,18 +9,22 @@
import logging
import pathlib
import typing as t
import io
import tempfile

from threatexchange import common
from threatexchange.cli.fetch_cmd import FetchCommand
from threatexchange.cli.helpers import FlexFilesInputAction
from threatexchange.exchanges.fetch_state import FetchedSignalMetadata

from threatexchange.signal_type.index import IndexMatch, SignalTypeIndex
from threatexchange.signal_type.index import (
IndexMatch,
SignalTypeIndex,
IndexMatchWithRotation,
)
from threatexchange.cli.exceptions import CommandError
from threatexchange.signal_type.signal_base import BytesHasher, SignalType
from threatexchange.cli.cli_config import CLISettings
from threatexchange.content_type.content_base import ContentType
from threatexchange.content_type.content_base import ContentType, RotationType
from threatexchange.content_type.photo import PhotoContent

from threatexchange.signal_type.signal_base import MatchesStr, TextHasher, FileHasher
Expand Down Expand Up @@ -52,10 +56,6 @@ class MatchCommand(command_base.Command):
$ threatexchange match text -- This is my cool text
```
# Additional options:
--rotation: For photo content, generate and match all 8 simple rotations
(0°, 90°, 180°, 270°, flip X, flip Y, flip diagonal +1, flip diagonal -1)
# Output
The output of this command is in the following format:
Expand Down Expand Up @@ -132,10 +132,10 @@ def init_argparse(cls, settings: CLISettings, ap: argparse.ArgumentParser) -> No
help="show all matches, not just one per collaboration",
)
ap.add_argument(
"--rotations",
"-R",
"--rotations",
"-R",
action="store_true",
help="for photos, generate and match all 8 simple rotations"
help="for photos, generate and match all 8 simple rotations",
)

def __init__(
Expand All @@ -147,7 +147,7 @@ def __init__(
show_false_positives: bool,
hide_disputed: bool,
all: bool,
rotations: bool = True
rotations: bool = False,
) -> None:
self.content_type = content_type
self.only_signal = only_signal
Expand All @@ -164,11 +164,10 @@ def __init__(
f"apply to {content_type.get_name()}",
2,
)

if self.rotations and not issubclass(content_type, PhotoContent):
raise CommandError(
"--rotations flag is only available for Photo content type",
2
"--rotations flag is only available for Photo content type", 2
)

def execute(self, settings: CLISettings) -> None:
Expand Down Expand Up @@ -215,24 +214,35 @@ def execute(self, settings: CLISettings) -> None:
for s_type, index in indices:
seen = set() # TODO - maybe take the highest certainty?
if self.as_hashes:
results = _match_hashes(path, s_type, index)
results_from_hashes = _match_hashes(path, s_type, index)
results: t.Sequence[IndexMatchWithRotation] = [
IndexMatchWithRotation(match=match)
for match in results_from_hashes
]
else:
results = _match_file(path, s_type, index, rotations=self.rotations)

for r in results:
metadatas: t.List[t.Tuple[str, FetchedSignalMetadata]] = r.metadata
if isinstance(r, IndexMatchWithRotation):
metadatas: t.List[t.Tuple[str, FetchedSignalMetadata]] = (
r.match.metadata
)
rotation_info = f" [{r.rotation_type.name}]"
# Supposed to be without whitespace, but let's make sure
distance_str = "".join(
r.match.similarity_info.pretty_str().split()
)

elif isinstance(r, IndexMatch):
metadatas = r.metadata
rotation_info = ""
distance_str = "".join(r.similarity_info.pretty_str().split())

for collab, fetched_data in metadatas:
if not self.all and collab in seen:
continue
seen.add(collab)

# Add rotation information if possible
rotation_info = ""
if hasattr(r.similarity_info, "rotation"):
rotation_info = f" [{r.similarity_info.rotation.name}]"

# Supposed to be without whitespace, but let's make sure
distance_str = "".join(r.similarity_info.pretty_str().split())
print(
s_type.get_name(),
distance_str + rotation_info,
Expand All @@ -242,35 +252,46 @@ def execute(self, settings: CLISettings) -> None:


def _match_file(
path: pathlib.Path,
s_type: t.Type[SignalType],
path: pathlib.Path,
s_type: t.Type[SignalType],
index: SignalTypeIndex,
rotations: bool = False
) -> t.Sequence[IndexMatch]:
rotations: bool = False,
) -> t.Sequence[IndexMatchWithRotation]:
if issubclass(s_type, MatchesStr):
return index.query(path.read_text())
matches = index.query(path.read_text())
return [IndexMatchWithRotation(match=match) for match in matches]

assert issubclass(s_type, FileHasher)

if not rotations or s_type != PhotoContent:
return index.query(s_type.hash_from_file(path))

matches = index.query(s_type.hash_from_file(path))
return [IndexMatchWithRotation(match=match) for match in matches]

# Handle rotations for photos
with open(path, "rb") as f:
image_data = f.read()

rotations = PhotoContent.all_simple_rotations(image_data)

rotated_images: t.Dict[RotationType, bytes] = PhotoContent.all_simple_rotations(
image_data
)
all_matches = []

for rotation_type, rotated_bytes in rotations.items():
# Create a temporary BytesIO object to simulate a file
temp_buffer = io.BytesIO(rotated_bytes)
matches = index.query(s_type.hash_from_file(temp_buffer))
for rotation_type, rotated_bytes in rotated_images.items():
# Create a temporary file to hold the image bytes
with tempfile.NamedTemporaryFile() as temp_file:
temp_file.write(rotated_bytes)
temp_file_path = pathlib.Path(temp_file.name)
matches = index.query(s_type.hash_from_file(temp_file_path))
temp_file_path.unlink() # Clean up the temporary file

# Add rotation information if any matches were found
matches_with_rotations = []
for match in matches:
matches_with_rotations.append(
IndexMatchWithRotation(match=match, rotation_type=rotation_type)
)

# Add rotation information if any matches were found
for match in matches:
match.similarity_info.rotation = rotation_type

all_matches.extend(matches)
all_matches.extend(matches_with_rotations)

return all_matches

Expand Down
3 changes: 2 additions & 1 deletion python-threatexchange/threatexchange/content_type/photo.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
"""
from PIL import Image
import io
import typing as t

from .content_base import ContentType, RotationType

Expand Down Expand Up @@ -82,7 +83,7 @@ def flip_minus1(cls, image_data: bytes) -> bytes:
return buffer.getvalue()

@classmethod
def all_simple_rotations(cls, image_data: bytes):
def all_simple_rotations(cls, image_data: bytes) -> t.Dict[RotationType, bytes]:
"""
Generate the 8 naive rotations of an image.
Expand Down
7 changes: 7 additions & 0 deletions python-threatexchange/threatexchange/signal_type/index.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
import pickle
import typing as t

from threatexchange.content_type.content_base import RotationType

T = t.TypeVar("T")
S_Co = t.TypeVar("S_Co", covariant=True, bound="SignalSimilarityInfo")
Expand Down Expand Up @@ -139,6 +140,12 @@ def __eq__(self, other: t.Any) -> bool:
IndexMatch = IndexMatchUntyped[SignalSimilarityInfo, T]


# Pairs a single index match with the rotation type recorded for it.
# `rotation_type` has a default, so a match can be wrapped without naming a
# rotation; only `match` is required.
@dataclass
class IndexMatchWithRotation(t.Generic[T]):
# The wrapped index match.
match: IndexMatchUntyped[SignalSimilarityInfo, T]
# Rotation associated with this match. Defaults to RotationType.ORIGINAL;
# presumably that value means "no rotation applied" -- TODO confirm against
# the RotationType definition in content_base.
rotation_type: RotationType = RotationType.ORIGINAL


Self = t.TypeVar("Self", bound="SignalTypeIndex")


Expand Down

0 comments on commit 594782b

Please sign in to comment.