Skip to content

Commit

Permalink
[hma] add recovery from lobj failure (#1674)
Browse files Browse the repository at this point in the history
  • Loading branch information
Dcallies authored Oct 30, 2024
1 parent 911f83a commit 5c47b24
Show file tree
Hide file tree
Showing 5 changed files with 95 additions and 19 deletions.
1 change: 1 addition & 0 deletions hasher-matcher-actioner/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ all = [
"pytest",
"types-Flask-Migrate",
"types-requests",
"types-psycopg2",
"types-python-dateutil",
"gunicorn",
"flask_apscheduler"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,15 @@ def reload_if_needed(self, store: interface.IUnifiedStore) -> None:
curr_checkpoint = store.get_last_index_build_checkpoint(self.signal_type)
if curr_checkpoint is not None and self.checkpoint != curr_checkpoint:
new_index = store.get_signal_type_index(self.signal_type)
assert new_index is not None
if new_index is None:
app: Flask = get_apscheduler().app
app.logger.error(
"CachedIndex[%s] index checkpoint(%r)"
+ " says new index available but unable to get it",
self.signal_type.get_name(),
curr_checkpoint,
)
return
self.index = new_index
self.checkpoint = curr_checkpoint
self.last_check_ts = now
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@

from flask import current_app
import flask_sqlalchemy
import psycopg2
from sqlalchemy import (
String,
Text,
Expand All @@ -31,6 +32,7 @@
UniqueConstraint,
BigInteger,
event,
text,
)
from sqlalchemy.dialects.postgresql import OID
from sqlalchemy.orm import (
Expand All @@ -43,6 +45,7 @@
from sqlalchemy.types import DateTime
from sqlalchemy.sql import func


from threatexchange.exchanges.collab_config import CollaborationConfigBase
from threatexchange.exchanges import auth
from threatexchange.exchanges.signal_exchange_api import (
Expand Down Expand Up @@ -374,6 +377,23 @@ class SignalIndex(db.Model): # type: ignore[name-defined]

serialized_index_large_object_oid: Mapped[int | None] = mapped_column(OID)

def index_lobj_exists(self) -> bool:
    """
    Return true if the index lobj exists and load_signal_index should work.

    In normal operation, this should always return true. However,
    we've observed in github.com/facebook/ThreatExchange/issues/1673
    that some partial failure is possible. This can be used to
    detect that condition.
    """
    # Use a bound parameter instead of f-string interpolation: it is safe
    # by construction, and it degrades gracefully when the oid column is
    # None/NULL (the comparison matches no rows -> returns False) instead
    # of producing the invalid SQL "WHERE oid = None".
    count = db.session.execute(
        text(
            "SELECT count(1) FROM pg_largeobject_metadata WHERE oid = :oid"
        ),
        {"oid": self.serialized_index_large_object_oid},
    ).scalar_one()
    return count == 1

def commit_signal_index(
self, index: SignalTypeIndex[int], checkpoint: SignalTypeIndexBuildCheckpoint
) -> t.Self:
Expand Down Expand Up @@ -403,9 +423,16 @@ def commit_signal_index(
duration_to_human_str(int(time.time() - store_start_time)),
)
if self.serialized_index_large_object_oid is not None:
old_obj = raw_conn.lobject(self.serialized_index_large_object_oid, "n") # type: ignore[attr-defined]
self._log("deallocating old lobject %d", old_obj.oid)
old_obj.unlink()
if self.index_lobj_exists():
old_obj = raw_conn.lobject(self.serialized_index_large_object_oid, "n") # type: ignore[attr-defined]
self._log("deallocating old lobject %d", old_obj.oid)
old_obj.unlink()
else:
self._log(
"old lobject %d doesn't exist? "
+ "This might be a previous partial failure",
self.serialized_index_large_object_oid,
)

self.serialized_index_large_object_oid = l_obj.oid
db.session.add(self)
Expand Down
25 changes: 10 additions & 15 deletions hasher-matcher-actioner/src/OpenMediaMatch/storage/postgres/impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -189,7 +189,9 @@ def get_signal_type_index(
)
).scalar_one_or_none()

return db_record.load_signal_index() if db_record is not None else None
if db_record is None or not db_record.index_lobj_exists():
return None
return db_record.load_signal_index()

def store_signal_type_index(
self,
Expand All @@ -213,22 +215,15 @@ def store_signal_type_index(
def get_last_index_build_checkpoint(
    self, signal_type: t.Type[SignalType]
) -> t.Optional[interface.SignalTypeIndexBuildCheckpoint]:
    """
    Return the checkpoint for the last index build of signal_type,
    or None if no usable index exists.

    Fetch the full SignalIndex record (rather than just the checkpoint
    columns) so we can confirm the backing large object still exists.
    A dangling record can be left behind by a partial unlink failure
    (github.com/facebook/ThreatExchange/issues/1673); treat that the
    same as having no index at all so callers trigger a rebuild.
    """
    db_record = database.db.session.execute(
        select(database.SignalIndex).where(
            database.SignalIndex.signal_type == signal_type.get_name()
        )
    ).scalar_one_or_none()

    if db_record is None or not db_record.index_lobj_exists():
        return None
    return db_record.as_checkpoint()

# Collabs
def exchange_update(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -189,6 +189,51 @@ def assert_content():
assert md5_index_status.total_hash_count == maker.count


def test_recover_from_index_unlink_partial_failure(storage: DefaultOMMStore):
    """
    SignalTypeIndex is stored in the postgres large object interface.
    As part of swapping over to the new index, we need to unlink the old
    object, but there is a race here because unlinking may not happen
    in a transaction. See github.com/facebook/ThreatExchange/issues/1673

    This test fakes the failure mode (a SignalIndex row whose backing
    large object is gone) and verifies the store both reports "no index"
    and recovers on the next rebuild.
    """
    bank_cfg = interface.BankConfig("TEST", matching_enabled_ratio=1.0)
    storage.bank_update(bank_cfg, create=True)
    storage.bank_add_content(
        bank_cfg.name, {VideoMD5Signal: VideoMD5Signal.get_examples()[0]}
    )

    def build_and_assert_ok():
        # Rebuild the index and confirm both the checkpoint and the
        # index itself are retrievable.
        build(storage)
        index_status = storage.get_last_index_build_checkpoint(VideoMD5Signal)
        # Guard before attribute access: if the checkpoint is missing we
        # want a clean assertion failure, not an AttributeError (this also
        # satisfies type checkers, since the return is Optional).
        assert index_status is not None
        assert index_status.total_hash_count == 1
        index = storage.get_signal_type_index(VideoMD5Signal)
        assert index is not None

    build_and_assert_ok()

    # Now we'll fake that the large object borked
    index_record = database.db.session.execute(
        select(database.SignalIndex).where(
            database.SignalIndex.signal_type == VideoMD5Signal.get_name()
        )
    ).scalar_one()
    assert index_record.index_lobj_exists() is True
    raw_conn = database.db.engine.raw_connection()
    old_obj = raw_conn.lobject(index_record.serialized_index_large_object_oid, "n")  # type: ignore[attr-defined]
    old_obj.unlink()
    raw_conn.commit()

    # Now we should have a dangling record
    index_status = storage.get_last_index_build_checkpoint(VideoMD5Signal)
    assert index_status is None
    index = storage.get_signal_type_index(VideoMD5Signal)
    assert index is None

    # We should be able to recover by rebuilding
    build_and_assert_ok()


class _UnknownSampleExchangeAPI(StaticSampleSignalExchangeAPI):
"""Returns all the sample data, but can't convert to any types"""

Expand Down

0 comments on commit 5c47b24

Please sign in to comment.