Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[pytx] Implement a new cleaner PDQ index solution #1695

Closed
wants to merge 4 commits into from
Closed
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
166 changes: 166 additions & 0 deletions python-threatexchange/threatexchange/signal_type/index2.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,166 @@
import typing as t
import faiss
import pickle
import numpy as np

from threatexchange.signal_type.index import (
IndexMatchUntyped,
SignalSimilarityInfoWithIntDistance,
SignalTypeIndex,
T as IndexT,
SignalSimilarityInfo,
IndexMatch,
)

# Generic payload type stored alongside each indexed hash.
T = t.TypeVar("T")
# Default maximum match distance (inclusive) for PDQ similarity lookups.
DEFAULT_MATCH_DIST = 31
# PDQ hashes are 256 bits; each hash is indexed as a 256-dimensional vector.
DIMENSIONALITY = 256


class _PDQHashIndex:
    """
    A wrapper around the faiss index for pickle serialization.

    FAISS index objects are not directly picklable; this wrapper converts
    to/from FAISS's serialized representation in __getstate__/__setstate__.
    """

    def __init__(self, faiss_index: faiss.Index) -> None:
        self.faiss_index = faiss_index

    def add(self, pdq_strings: t.Sequence[str]) -> None:
        """
        Add PDQ hashes to the FAISS index.

        Args:
            pdq_strings (Sequence[str]): PDQ hash strings to add
        """
        vectors = self._convert_pdq_strings_to_ndarray(pdq_strings)
        self.faiss_index.add(vectors)

    def search(
        self, queries: t.Sequence[str], threshold: int = DEFAULT_MATCH_DIST
    ) -> t.List[t.List[t.Any]]:
        """
        Search the FAISS index for matches to the given PDQ queries.

        Args:
            queries (Sequence[str]): The PDQ signal strings to search for.
            threshold (int): The maximum distance threshold for matches.

        Returns:
            One list per query, each containing (faiss_id, distance) pairs
            for every indexed hash within `threshold` of that query.
        """
        query_array: np.ndarray = self._convert_pdq_strings_to_ndarray(queries)
        # range_search's radius is exclusive, so +1 makes `threshold` inclusive.
        limits, distances, indices = self.faiss_index.range_search(
            query_array, threshold + 1
        )

        results: t.List[t.List[t.Any]] = []
        for i in range(len(queries)):
            # limits[i]:limits[i+1] delimits the flat result slice for query i.
            matches = [idx.item() for idx in indices[limits[i] : limits[i + 1]]]
            dists = [dist for dist in distances[limits[i] : limits[i + 1]]]
            results.append(list(zip(matches, dists)))
        return results

    def __getstate__(self):
        # Serialize the FAISS index into a picklable buffer.
        data = faiss.serialize_index(self.faiss_index)
        return data

    def __setstate__(self, data):
        self.faiss_index = faiss.deserialize_index(data)

    def _convert_pdq_strings_to_ndarray(
        self, pdq_strings: t.Sequence[str]
    ) -> np.ndarray:
        """
        Convert multiple PDQ hash strings to a numpy array.

        Args:
            pdq_strings (Sequence[str]): A sequence of 64-character hexadecimal PDQ hash strings

        Returns:
            np.ndarray: A 2D array of shape (n_queries, 256) where each row is the full PDQ hash as a bit array

        Raises:
            ValueError: If any input is not a valid 64-character hex string.
        """
        hash_arrays = []
        for pdq_str in pdq_strings:
            try:
                if len(pdq_str) != 64:
                    raise ValueError("expected 64 hex characters")
                # Decode the 64 hex chars into 32 bytes, then expand each byte
                # into its 8 bits MSB-first, matching the big-endian bit order
                # of format(int(s, 16), "0256b").
                hash_bytes = bytes.fromhex(pdq_str)
            except (ValueError, TypeError) as e:
                raise ValueError(f"Invalid PDQ hash string: {pdq_str}") from e
            bits = np.unpackbits(np.frombuffer(hash_bytes, dtype=np.uint8))
            hash_arrays.append(bits.astype(np.float32))

        # Convert list of arrays to a single 2D array
        return np.array(hash_arrays, dtype=np.float32)


Self = t.TypeVar("Self", bound="SignalTypeIndex2")
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

unused?


PDQIndexMatch = IndexMatchUntyped[SignalSimilarityInfoWithIntDistance, IndexT]


class SignalTypeIndex2(t.Generic[T]):
    """
    A FAISS-backed index mapping PDQ hashes to arbitrary entry payloads.

    Duplicate hashes share a single FAISS vector: all payloads added under
    the same hash accumulate under one internal FAISS id.
    """

    def __init__(
        self,
        threshold: int = DEFAULT_MATCH_DIST,
        faiss_index: t.Optional[faiss.Index] = None,
    ) -> None:
        """
        Initialize the PDQ index.

        Args:
            threshold (int): The maximum distance threshold for matches.
            faiss_index (faiss.Index): An optional pre-existing FAISS index to use.
        """
        super().__init__()
        if faiss_index is None:
            # Use a simple brute-force FAISS index by default
            faiss_index = faiss.IndexFlatL2(DIMENSIONALITY)
        self.faiss_index = _PDQHashIndex(faiss_index)
        self.threshold = threshold
        # Set of hashes already added (kept for cheap membership checks).
        self._deduper: t.Set[str] = set()
        # Maps each indexed hash to its FAISS id, which equals its position
        # in self._entries (FAISS assigns ids 0..n-1 in insertion order).
        self._hash_to_id: t.Dict[str, int] = {}
        # _entries[faiss_id] is the list of payloads stored for that hash.
        self._entries: t.List[t.List[T]] = []

    def query(self, query: str) -> t.List[PDQIndexMatch[T]]:
        """
        Find all entries within `self.threshold` distance of the query hash.

        Args:
            query (str): The PDQ hash string to search for.

        Returns:
            One PDQIndexMatch per stored payload within the threshold.
        """
        results = self.faiss_index.search([query], self.threshold)
        return [
            PDQIndexMatch(
                SignalSimilarityInfoWithIntDistance(distance=int(distf)), entry
            )
            for idx, distf in results[0]
            for entry in self._entries[idx]
        ]

    def add(self, pdq_hash: str, entry: T) -> None:
        """
        Add a PDQ hash and its associated entry to the index.

        Args:
            pdq_hash (str): The PDQ hash string
            entry (T): The associated entry data
        """
        existing_id = self._hash_to_id.get(pdq_hash)
        if existing_id is None:
            # New hash: it receives the next sequential FAISS id.
            next_id = len(self._entries)
            self._deduper.add(pdq_hash)
            self._hash_to_id[pdq_hash] = next_id
            self.faiss_index.add([pdq_hash])
            self._entries.append([entry])
        else:
            # Hash already indexed: don't re-add to FAISS, just record the
            # extra payload. (Previously this did list(self._deduper).index(),
            # which was O(n) and relied on unstable set iteration order.)
            self._entries[existing_id].append(entry)

    def serialize(self, fout: t.BinaryIO) -> None:
        """
        Serialize the PDQ index to a binary stream.
        """
        fout.write(pickle.dumps(self))

    @classmethod
    def deserialize(cls, fin: t.BinaryIO) -> "SignalTypeIndex2[T]":
        """
        Deserialize a PDQ index from a binary stream.

        NOTE: pickle deserialization executes arbitrary code — only call
        this on trusted data.
        """
        return pickle.loads(fin.read())
Original file line number Diff line number Diff line change
@@ -0,0 +1,159 @@
import pytest
import io
import faiss
from threatexchange.signal_type.index2 import (
SignalTypeIndex2,
_PDQHashIndex,
DIMENSIONALITY,
DEFAULT_MATCH_DIST,
)


@pytest.fixture
def empty_index():
    """Fixture providing a freshly constructed, empty string-payload index."""
    index: SignalTypeIndex2[str] = SignalTypeIndex2[str]()
    return index


@pytest.fixture
def custom_index_with_threshold():
    """Fixture for an index built around a non-default FAISS index and threshold."""
    return SignalTypeIndex2[str](
        faiss_index=faiss.IndexFlatL2(DIMENSIONALITY + 1),
        threshold=DEFAULT_MATCH_DIST + 1,
    )


@pytest.fixture
def sample_index():
    """Fixture for an index with a small sample set."""
    index = SignalTypeIndex2[str]()
    sample_hashes = [
        "f" * 64,  # All f's
        "0" * 64,  # All 0's
        "a" * 64,  # All a's
        "f8f8f0cee0f4a84f06370a22038f63f0b36e2ed596621e1d33e6b39c4e9c9b22",  # Sample hash
    ]
    return index, sample_hashes


def test_init(empty_index) -> None:
    """A default-constructed index has the default threshold, a flat L2 index, and no data."""
    assert empty_index.threshold == DEFAULT_MATCH_DIST
    wrapper = empty_index.faiss_index
    assert isinstance(wrapper, _PDQHashIndex)
    inner = wrapper.faiss_index
    assert isinstance(inner, faiss.IndexFlatL2)
    assert inner.d == DIMENSIONALITY
    assert empty_index._deduper == set()
    assert empty_index._entries == []


def test_serialize_deserialize(empty_index) -> None:
    """An empty index round-trips through serialize/deserialize unchanged."""
    buffer = io.BytesIO()
    empty_index.serialize(buffer)
    buffer.seek(0)
    restored: SignalTypeIndex2[str] = SignalTypeIndex2.deserialize(buffer)

    assert isinstance(restored, SignalTypeIndex2)
    assert restored.threshold == empty_index.threshold
    assert isinstance(restored.faiss_index, _PDQHashIndex)
    inner = restored.faiss_index.faiss_index
    assert isinstance(inner, faiss.IndexFlatL2)
    assert inner.d == DIMENSIONALITY
    assert restored._deduper == empty_index._deduper
    assert restored._entries == empty_index._entries


def test_serialize_deserialize_with_custom_index_threshold(
    custom_index_with_threshold,
) -> None:
    """A custom-dimension, custom-threshold index round-trips intact."""
    buffer = io.BytesIO()
    custom_index_with_threshold.serialize(buffer)
    buffer.seek(0)
    restored: SignalTypeIndex2[int] = SignalTypeIndex2.deserialize(buffer)

    assert isinstance(restored, SignalTypeIndex2)
    assert restored.threshold == custom_index_with_threshold.threshold
    assert isinstance(restored.faiss_index, _PDQHashIndex)
    inner = restored.faiss_index.faiss_index
    assert isinstance(inner, faiss.IndexFlatL2)
    # The custom index was built with one extra dimension.
    assert inner.d == DIMENSIONALITY + 1
    assert restored._deduper == custom_index_with_threshold._deduper
    assert restored._entries == custom_index_with_threshold._entries


def test_empty_index_query(empty_index):
    """Querying an index with no entries yields no matches."""
    query_hash = "f8f8f0cee0f4a84f06370a22038f63f0b36e2ed596621e1d33e6b39c4e9c9b22"
    assert empty_index.query(query=query_hash) == []


def test_sample_set_exact_match(sample_index):
    """An indexed hash queried verbatim matches itself at distance 0."""
    index, pdq_hashes = sample_index

    # Each hash doubles as its own payload.
    for hash_str in pdq_hashes:
        index.add(hash_str, hash_str)

    matches = index.query(pdq_hashes[0])

    assert len(matches) == 1
    # An exact match must report zero distance.
    assert matches[0].similarity_info.distance == 0


def test_sample_set_near_match(sample_index):
    """A hash a few bits away from an indexed one still matches, at distance > 0."""
    index, pdq_hashes = sample_index

    # Each hash doubles as its own payload.
    for hash_str in pdq_hashes:
        index.add(hash_str, hash_str)

    # Flip the 4 lowest bits of the first hash to create a near-duplicate.
    base_hash = pdq_hashes[0]
    near_hash = hex(int(base_hash, 16) ^ 0xF)[2:].zfill(64)

    matches = index.query(near_hash)
    assert len(matches) > 0  # Should find near matches
    assert matches[0].similarity_info.distance > 0


def test_sample_set_threshold(sample_index):
    """A looser threshold must admit more matches than a stricter one."""
    _, pdq_hashes = sample_index

    narrow_index = SignalTypeIndex2[str](threshold=10)  # Strict matching
    wide_index = SignalTypeIndex2[str](threshold=50)  # Loose matching

    for hash_str in pdq_hashes:
        narrow_index.add(hash_str, hash_str)
        wide_index.add(hash_str, hash_str)

    # Flip the 20 lowest bits of the first hash: inside the wide threshold,
    # outside the narrow one.
    base_hash = pdq_hashes[0]
    test_hash = hex(int(base_hash, 16) ^ ((1 << 20) - 1))[2:].zfill(64)

    narrow_results = narrow_index.query(test_hash)
    wide_results = wide_index.query(test_hash)

    assert len(wide_results) > len(narrow_results)  # Wide threshold should match more


def test_duplicate_handling(sample_index):
    """Adding the same hash with several payloads returns all of them on query."""
    index, pdq_hashes = sample_index

    # Register three distinct payloads under one hash.
    test_hash = pdq_hashes[0]
    for i in range(3):
        index.add(test_hash, f"entry_{i}")

    matches = index.query(test_hash)

    # Every payload associated with the hash should come back.
    assert len(matches) == 3