-
Notifications
You must be signed in to change notification settings - Fork 321
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[pytx] Implement a new cleaner PDQ index solution #1695
Changes from 3 commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,166 @@ | ||
import typing as t | ||
import faiss | ||
import pickle | ||
import numpy as np | ||
|
||
from threatexchange.signal_type.index import ( | ||
IndexMatchUntyped, | ||
SignalSimilarityInfoWithIntDistance, | ||
SignalTypeIndex, | ||
T as IndexT, | ||
SignalSimilarityInfo, | ||
IndexMatch, | ||
) | ||
|
||
# Entry payload type stored alongside each hash in the index.
T = t.TypeVar("T")
# Default Hamming-distance threshold: two PDQ hashes within this distance
# are considered a match (PDQ's conventional match cutoff).
DEFAULT_MATCH_DIST = 31
# A PDQ hash is 256 bits, so index vectors have 256 dimensions.
DIMENSIONALITY = 256
||
class _PDQHashIndex:
    """
    A wrapper around the faiss index for pickle serialization.

    faiss indexes are C++ objects that pickle cannot handle directly, so
    __getstate__/__setstate__ round-trip through faiss's own serializer.
    """

    def __init__(self, faiss_index: faiss.Index) -> None:
        self.faiss_index = faiss_index

    def add(self, pdq_strings: t.Sequence[str]) -> None:
        """
        Add PDQ hashes to the FAISS index.

        Args:
            pdq_strings (Sequence[str]): PDQ hash strings to add
        """
        vectors = self._convert_pdq_strings_to_ndarray(pdq_strings)
        self.faiss_index.add(vectors)

    def search(
        self, queries: t.Sequence[str], threshold: int = DEFAULT_MATCH_DIST
    ) -> t.List[t.List[t.Any]]:
        """
        Search the FAISS index for matches to the given PDQ queries.

        Args:
            queries (Sequence[str]): The PDQ signal strings to search for.
            threshold (int): The maximum distance threshold for matches.

        Returns:
            List[List[Any]]: One list per query, in query order. Each inner
            list holds (faiss_id, distance) pairs for every indexed hash
            within `threshold` of that query.
        """
        query_array: np.ndarray = self._convert_pdq_strings_to_ndarray(queries)
        # range_search's radius is exclusive, so +1 makes `threshold` inclusive.
        limits, distances, indices = self.faiss_index.range_search(
            query_array, threshold + 1
        )

        # range_search returns flat result arrays; limits[i]:limits[i+1]
        # delimits the slice belonging to query i.
        results: t.List[t.List[t.Any]] = []
        for i in range(len(queries)):
            start, end = limits[i], limits[i + 1]
            matches = [idx.item() for idx in indices[start:end]]
            dists = list(distances[start:end])
            results.append(list(zip(matches, dists)))
        return results

    def __getstate__(self):
        # Serialize the C++ faiss index into a numpy buffer pickle can handle.
        data = faiss.serialize_index(self.faiss_index)
        return data

    def __setstate__(self, data):
        self.faiss_index = faiss.deserialize_index(data)

    def _convert_pdq_strings_to_ndarray(
        self, pdq_strings: t.Sequence[str]
    ) -> np.ndarray:
        """
        Convert multiple PDQ hash strings to a numpy array.

        Args:
            pdq_strings (Sequence[str]): A sequence of 64-character hexadecimal PDQ hash strings

        Returns:
            np.ndarray: A 2D array of shape (n_queries, 256) where each row is the full PDQ hash as a bit array

        Raises:
            ValueError: If any string is not valid 64-character hex.
        """
        hash_arrays = []
        for pdq_str in pdq_strings:
            try:
                # bytes.fromhex + unpackbits emits the 256 bits MSB-first,
                # identical to format(int(s, 16), "0256b") but at C speed.
                hash_bytes = bytes.fromhex(pdq_str)
                bits = np.unpackbits(np.frombuffer(hash_bytes, dtype=np.uint8))
                if bits.size != DIMENSIONALITY:
                    # Wrong-length input would silently produce a vector that
                    # mismatches the index dimensionality; fail fast instead.
                    raise ValueError("expected a 64-character hex string")
                hash_arrays.append(bits.astype(np.float32))
            except (ValueError, TypeError) as e:
                raise ValueError(f"Invalid PDQ hash string: {pdq_str}") from e

        # Convert list of arrays to a single 2D array
        return np.array(hash_arrays, dtype=np.float32)
|
||
|
||
# TypeVar for annotating methods that return their own class' type.
Self = t.TypeVar("Self", bound="SignalTypeIndex2")

# Match result parameterized over the caller's entry type, carrying an
# integer (Hamming) distance as its similarity measure.
PDQIndexMatch = IndexMatchUntyped[SignalSimilarityInfoWithIntDistance, IndexT]
||
class SignalTypeIndex2(t.Generic[T]):
    """
    FAISS-backed index mapping PDQ hashes to arbitrary entry payloads.

    Duplicate hashes are deduplicated: each distinct hash occupies one row
    of the FAISS index, and every entry added under that hash is returned
    together on a match.

    NOTE(review): this currently assumes FAISS compatibility at the generic
    level; confirm whether it is meant as a PDQ-specific index or a general
    SignalTypeIndex replacement (see issue #1613 discussion).
    """

    def __init__(
        self,
        threshold: int = DEFAULT_MATCH_DIST,
        faiss_index: t.Optional[faiss.Index] = None,
    ) -> None:
        """
        Initialize the PDQ index.

        Args:
            threshold (int): The maximum distance threshold for matches.
            faiss_index (faiss.Index): An optional pre-existing FAISS index to use.
        """
        super().__init__()
        if faiss_index is None:
            # Use a simple brute-force FAISS index by default
            faiss_index = faiss.IndexFlatL2(DIMENSIONALITY)
        self.faiss_index = _PDQHashIndex(faiss_index)
        self.threshold = threshold
        self._deduper: t.Set[str] = set()
        # Maps hash -> faiss id (== its position in _entries). Sets are
        # unordered, so deriving the position from _deduper would be O(n)
        # AND non-deterministic; this dict makes duplicate adds O(1) and
        # stable (per review feedback).
        self._hash_to_id: t.Dict[str, int] = {}
        self._entries: t.List[t.List[T]] = []

    def query(self, query: str) -> t.List[PDQIndexMatch[T]]:
        """
        Return one match per stored entry whose hash is within
        `self.threshold` Hamming distance of the query hash.
        """
        results = self.faiss_index.search([query], self.threshold)
        return [
            PDQIndexMatch(
                SignalSimilarityInfoWithIntDistance(distance=int(distf)), entry
            )
            for idx, distf in results[0]
            for entry in self._entries[idx]
        ]

    def add(self, pdq_hash: str, entry: T) -> None:
        """
        Add a PDQ hash and its associated entry to the index.

        Args:
            pdq_hash (str): The PDQ hash string
            entry (T): The associated entry data
        """
        if pdq_hash not in self._deduper:
            self._deduper.add(pdq_hash)
            # The new hash becomes the next faiss row, i.e. len(_entries).
            self._hash_to_id[pdq_hash] = len(self._entries)
            self.faiss_index.add([pdq_hash])
            self._entries.append([entry])
        else:
            # If hash exists, append entry to its existing entry list in O(1).
            self._entries[self._hash_to_id[pdq_hash]].append(entry)

    def serialize(self, fout: t.BinaryIO) -> None:
        """
        Serialize the PDQ index to a binary stream.
        """
        fout.write(pickle.dumps(self))

    @classmethod
    def deserialize(cls, fin: t.BinaryIO) -> "SignalTypeIndex2[T]":
        """
        Deserialize a PDQ index from a binary stream.

        WARNING: pickle.loads executes arbitrary code on malicious input;
        only deserialize streams produced by a trusted `serialize` call.
        """
        return pickle.loads(fin.read())
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,159 @@ | ||
import pytest | ||
import io | ||
import faiss | ||
from threatexchange.signal_type.index2 import ( | ||
SignalTypeIndex2, | ||
_PDQHashIndex, | ||
DIMENSIONALITY, | ||
DEFAULT_MATCH_DIST, | ||
) | ||
|
||
|
||
@pytest.fixture
def empty_index():
    """Fixture for an empty index."""
    index: SignalTypeIndex2[str] = SignalTypeIndex2[str]()
    return index
||
|
||
@pytest.fixture
def custom_index_with_threshold():
    """Fixture for an index with custom index and threshold."""
    return SignalTypeIndex2[str](
        faiss_index=faiss.IndexFlatL2(DIMENSIONALITY + 1),
        threshold=DEFAULT_MATCH_DIST + 1,
    )
||
|
||
@pytest.fixture
def sample_index():
    """Fixture for an index with a small sample set."""
    pdq_hashes = [
        "f" * 64,  # All f's
        "0" * 64,  # All 0's
        "a" * 64,  # All a's
        "f8f8f0cee0f4a84f06370a22038f63f0b36e2ed596621e1d33e6b39c4e9c9b22",  # Sample hash
    ]
    return SignalTypeIndex2[str](), pdq_hashes
||
|
||
def test_init(empty_index) -> None:
    """A default-constructed index has default threshold and a flat L2 index."""
    assert empty_index.threshold == DEFAULT_MATCH_DIST
    wrapper = empty_index.faiss_index
    assert isinstance(wrapper, _PDQHashIndex)
    assert isinstance(wrapper.faiss_index, faiss.IndexFlatL2)
    assert wrapper.faiss_index.d == DIMENSIONALITY
    assert empty_index._deduper == set()
    assert empty_index._entries == []
||
|
||
def test_serialize_deserialize(empty_index) -> None:
    """Round-tripping an empty index through serialize preserves its state."""
    buffer = io.BytesIO()
    empty_index.serialize(buffer)
    buffer.seek(0)
    restored: SignalTypeIndex2[str] = SignalTypeIndex2.deserialize(buffer)

    assert isinstance(restored, SignalTypeIndex2)
    assert restored.threshold == empty_index.threshold
    assert isinstance(restored.faiss_index, _PDQHashIndex)
    assert isinstance(restored.faiss_index.faiss_index, faiss.IndexFlatL2)
    assert restored.faiss_index.faiss_index.d == DIMENSIONALITY
    assert restored._deduper == empty_index._deduper
    assert restored._entries == empty_index._entries
||
|
||
def test_serialize_deserialize_with_custom_index_threshold(
    custom_index_with_threshold,
) -> None:
    """Round-tripping preserves a non-default dimensionality and threshold."""
    original = custom_index_with_threshold
    buffer = io.BytesIO()
    original.serialize(buffer)
    buffer.seek(0)
    restored: SignalTypeIndex2[int] = SignalTypeIndex2.deserialize(buffer)

    assert isinstance(restored, SignalTypeIndex2)
    assert restored.threshold == original.threshold
    assert isinstance(restored.faiss_index, _PDQHashIndex)
    assert isinstance(restored.faiss_index.faiss_index, faiss.IndexFlatL2)
    assert restored.faiss_index.faiss_index.d == DIMENSIONALITY + 1
    assert restored._deduper == original._deduper
    assert restored._entries == original._entries
||
|
||
def test_empty_index_query(empty_index):
    """Test querying an empty index."""
    query_hash = "f8f8f0cee0f4a84f06370a22038f63f0b36e2ed596621e1d33e6b39c4e9c9b22"

    # An index with nothing in it can match nothing.
    assert empty_index.query(query=query_hash) == []
||
|
||
def test_sample_set_exact_match(sample_index):
    """Test exact matches in sample set."""
    index, pdq_hashes = sample_index

    # Add hashes to index, using each hash as its own identifier.
    for hash_str in pdq_hashes:
        index.add(hash_str, hash_str)

    # Querying with an indexed hash should match exactly once, at distance 0.
    results = index.query(pdq_hashes[0])

    assert len(results) == 1
    assert results[0].similarity_info.distance == 0
||
|
||
def test_sample_set_near_match(sample_index):
    """Test near matches in sample set."""
    index, pdq_hashes = sample_index

    for hash_str in pdq_hashes:
        index.add(hash_str, hash_str)  # Using hash as its own identifier

    # Flip the low 4 bits of the first hash to create a near-match.
    base_hash = pdq_hashes[0]
    near_hash = hex(int(base_hash, 16) ^ 0xF)[2:].zfill(64)

    results = index.query(near_hash)
    assert results  # Should find near matches
    assert results[0].similarity_info.distance > 0
||
|
||
def test_sample_set_threshold(sample_index):
    """Test distance threshold behavior."""
    _, pdq_hashes = sample_index

    narrow_index = SignalTypeIndex2[str](threshold=10)  # Strict matching
    wide_index = SignalTypeIndex2[str](threshold=50)  # Loose matching

    for hash_str in pdq_hashes:
        narrow_index.add(hash_str, hash_str)
        wide_index.add(hash_str, hash_str)

    # Flip the low 20 bits of the first hash: within the wide threshold
    # but outside the narrow one.
    test_hash = hex(int(pdq_hashes[0], 16) ^ ((1 << 20) - 1))[2:].zfill(64)

    narrow_results = narrow_index.query(test_hash)
    wide_results = wide_index.query(test_hash)

    # The wide threshold should match more than the narrow one.
    assert len(wide_results) > len(narrow_results)
||
|
||
def test_duplicate_handling(sample_index):
    """Test how the index handles duplicate entries."""
    index, pdq_hashes = sample_index

    # Add the same hash three times, each with a distinct entry.
    test_hash = pdq_hashes[0]
    index.add(test_hash, "entry_0")
    index.add(test_hash, "entry_1")
    index.add(test_hash, "entry_2")

    # One query should surface every entry stored under that hash.
    assert len(index.query(test_hash)) == 3
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
unused?