Skip to content

Commit

Permalink
[omm] Seed data fixups (facebook#1506)
Browse files Browse the repository at this point in the history
  • Loading branch information
Dcallies authored Jan 11, 2024
1 parent e9d0696 commit 8cd9601
Show file tree
Hide file tree
Showing 4 changed files with 76 additions and 36 deletions.
37 changes: 5 additions & 32 deletions open-media-match/src/OpenMediaMatch/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
from OpenMediaMatch.persistence import get_storage
from OpenMediaMatch.blueprints import development, hashing, matching, curation, ui
from OpenMediaMatch.storage.interface import BankConfig
from OpenMediaMatch.utils import dev_utils


def _is_debug_mode():
Expand Down Expand Up @@ -177,18 +178,9 @@ def site_map():
return routes

@app.cli.command("seed")
def seed_data():
"""Insert plausible-looking data into the database layer"""
from threatexchange.signal_type.pdq.signal import PdqSignal

bank_name = "SEED_BANK"

storage = get_storage()
storage.bank_update(BankConfig(name=bank_name, matching_enabled_ratio=1.0))

for st in (PdqSignal, VideoMD5Signal):
for example in st.get_examples():
storage.bank_add_content(bank_name, {st.get_name(): example})
def seed_data() -> None:
"""Add sample data API connection"""
dev_utils.seed_sample()

@app.cli.command("big-seed")
@click.option("-b", "--banks", default=100, show_default=True)
Expand All @@ -198,26 +190,7 @@ def seed_enourmous(banks: int, seeds: int) -> None:
Seed the database with a large number of banks and hashes
It will generate n banks and put n/m hashes on each bank
"""
storage = get_storage()

types: list[t.Type[CanGenerateRandomSignal]] = [PdqSignal, VideoMD5Signal]

for i in range(banks):
# create bank
bank = BankConfig(name=f"SEED_BANK_{i}", matching_enabled_ratio=1.0)
storage.bank_update(bank, create=True)

# Add hashes
for _ in range(seeds // banks):
# grab randomly either PDQ or MD5 signal
signal_type = random.choice(types)
random_hash = signal_type.get_random_signal()

storage.bank_add_content(
bank.name, {t.cast(t.Type[SignalType], signal_type): random_hash}
)

print("Finished adding hashes to", bank.name)
dev_utils.seed_banks_random(banks, seeds)

@app.cli.command("fetch")
def fetch():
Expand Down
13 changes: 13 additions & 0 deletions open-media-match/src/OpenMediaMatch/blueprints/ui.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

from OpenMediaMatch.blueprints import matching, curation, hashing
from OpenMediaMatch.persistence import get_storage
from OpenMediaMatch.utils import dev_utils
from OpenMediaMatch.storage.postgres.flask_utils import reset_tables
from OpenMediaMatch.storage.postgres.database import db
from OpenMediaMatch.utils.time_utils import duration_to_human_str
Expand Down Expand Up @@ -124,6 +125,18 @@ def upload():
return {"hashes": signals, "banks": sorted(banks)}


@bp.route("/seed_sample", methods=["POST"])
def seed_sample():
dev_utils.seed_sample()
return redirect("./")


@bp.route("/seed_banks", methods=["POST"])
def seed_banks():
dev_utils.seed_banks_random()
return redirect("./")


@bp.route("/factory_reset", methods=["POST"])
def factory_reset():
reset_tables()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,17 @@
<div class="alert alert-primary" role="alert">
<div class="d-flex flex-row align-items-center">
<div class="me-2">Running server in development mode!</div>
<div class="me-2">
<form action="/ui/factory_reset" method="post" enctype="multipart/form-data">
<button type="submit" class="btn btn-danger">Factory Reset</button>
<div class="hstack gap-2">
<form action="/ui/seed_sample" method="post">
<button type="submit" class="btn btn-primary">Seed Sample API</button>
</form>
<form action="/ui/seed_banks" method="post">
<button type="submit" class="btn btn-primary">Seed Banks</button>
</form>
<form action="/ui/factory_reset" method="post">
<button type="submit" class="btn btn-outline-danger">Factory Reset</button>
</form>
</div>
</div>
</div>
</div>
</div>
48 changes: 48 additions & 0 deletions open-media-match/src/OpenMediaMatch/utils/dev_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.

import typing as t

from threatexchange.signal_type.pdq.signal import PdqSignal
from threatexchange.signal_type.md5 import VideoMD5Signal
from threatexchange.exchanges.collab_config import CollaborationConfigBase
from threatexchange.exchanges.impl.static_sample import StaticSampleSignalExchangeAPI
from threatexchange.signal_type.signal_base import SignalType, CanGenerateRandomSignal

from OpenMediaMatch import persistence
from OpenMediaMatch.storage.interface import BankConfig


def seed_sample() -> None:
storage = persistence.get_storage()
storage.exchange_update(
CollaborationConfigBase(
name="SEED_SAMPLE",
api=StaticSampleSignalExchangeAPI.get_name(),
enabled=True,
),
create=True,
)


def seed_banks_random(banks: int = 2, seeds: int = 10000) -> None:
"""
Seed the database with a large number of banks and hashes
It will generate n banks and put n/m hashes on each bank
"""
storage = persistence.get_storage()

types: list[t.Type[CanGenerateRandomSignal]] = [PdqSignal, VideoMD5Signal]

for i in range(banks):
# create bank
bank = BankConfig(name=f"SEED_BANK_{i}", matching_enabled_ratio=1.0)
storage.bank_update(bank, create=True)

# Add hashes
for i in range(seeds // banks):
signal_type = types[i % len(types)]
random_hash = signal_type.get_random_signal()

storage.bank_add_content(
bank.name, {t.cast(t.Type[SignalType], signal_type): random_hash}
)

0 comments on commit 8cd9601

Please sign in to comment.